mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Cache: new Cache format in decoder-only models (#31421)
* draft bart with new cache * add cache for decoder-only models * revert utils * modify docstring * revert bart * minor fixes * fix copies (not related) * revert tests * remove enc-dec related code * remove bloom * remove opt (enc-dec) * update docstring * git, codegen, gpt_neo, gpt_neox, gpj * clean up * copied from statements * revert * tmp * update warning msg * forgot git * add more flags * run-slow git,codegen,gpt_neo,gpt_neox,gpj * add cache flag to VLMs * remove files * style * video LLMs also need a flag * style * llava will go in another PR * style * [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics * Update src/transformers/models/gpt_neo/modeling_gpt_neo.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * copy from * deprecate until v4.45 and warn if not training * nit * fix test * test static cache * add more tests and fix models * fix copies * return sliding window mask * run slow tests & fix + codestyle * one more falcon fix for alibi --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
6af0854efa
commit
a30c865f99
@ -1016,7 +1016,9 @@ class StaticCache(Cache):
|
|||||||
|
|
||||||
self.dtype = dtype if dtype is not None else torch.float32
|
self.dtype = dtype if dtype is not None else torch.float32
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads
|
||||||
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
|
else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
|
@ -1473,7 +1473,7 @@ class GenerationMixin:
|
|||||||
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
|
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
|
||||||
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
|
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
|
||||||
# models. May cause trobles with non-text modalities.
|
# models. May cause trobles with non-text modalities.
|
||||||
cache_dtype = self.lm_head.weight.dtype
|
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||||
|
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config,
|
"config": self.config,
|
||||||
|
@ -22,6 +22,8 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
@ -34,6 +36,60 @@ _CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
|
|||||||
_CONFIG_FOR_DOC = "CodeGenConfig"
|
_CONFIG_FOR_DOC = "CodeGenConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
|
# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
|
||||||
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
||||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
||||||
@ -57,20 +113,19 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
|
|||||||
|
|
||||||
|
|
||||||
class CodeGenAttention(nn.Module):
|
class CodeGenAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
|
||||||
"causal_mask",
|
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
|
||||||
1, 1, max_positions, max_positions
|
|
||||||
),
|
|
||||||
persistent=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
@ -114,27 +169,17 @@ class CodeGenAttention(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
):
|
):
|
||||||
# compute causal mask from causal mask buffer
|
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
|
||||||
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
|
||||||
|
|
||||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||||
query = query.to(torch.float32)
|
query = query.to(torch.float32)
|
||||||
key = key.to(torch.float32)
|
key = key.to(torch.float32)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
attn_weights = attn_weights / self.scale_attn
|
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
|
||||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights += causal_mask
|
||||||
|
|
||||||
|
attn_weights = attn_weights / self.scale_attn
|
||||||
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
||||||
attn_weights = attn_weights.to(value.dtype)
|
attn_weights = attn_weights.to(value.dtype)
|
||||||
attn_weights = self.attn_dropout(attn_weights)
|
attn_weights = self.attn_dropout(attn_weights)
|
||||||
@ -150,12 +195,13 @@ class CodeGenAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.FloatTensor],
|
hidden_states: Optional[torch.FloatTensor],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[
|
) -> Union[
|
||||||
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
||||||
@ -200,18 +246,16 @@ class CodeGenAttention(nn.Module):
|
|||||||
key = key.permute(0, 2, 1, 3)
|
key = key.permute(0, 2, 1, 3)
|
||||||
query = query.permute(0, 2, 1, 3)
|
query = query.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
|
||||||
|
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {
|
||||||
past_value = layer_past[1]
|
"sin": sin,
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
"cos": cos,
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
"partial_rotation_size": self.rotary_dim,
|
||||||
|
"cache_position": cache_position,
|
||||||
if use_cache is True:
|
}
|
||||||
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
|
key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
|
||||||
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
|
|
||||||
present = (key.to(hidden_states.dtype), value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
# compute self-attention: V x Softmax(QK^T)
|
# compute self-attention: V x Softmax(QK^T)
|
||||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||||
@ -220,7 +264,7 @@ class CodeGenAttention(nn.Module):
|
|||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights,)
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
@ -250,22 +294,23 @@ class CodeGenMLP(nn.Module):
|
|||||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
|
# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
|
||||||
class CodeGenBlock(nn.Module):
|
class CodeGenBlock(nn.Module):
|
||||||
# Ignore copy
|
# Ignore copy
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
||||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
self.attn = CodeGenAttention(config)
|
self.attn = CodeGenAttention(config, layer_idx)
|
||||||
self.mlp = CodeGenMLP(inner_dim, config)
|
self.mlp = CodeGenMLP(inner_dim, config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.FloatTensor],
|
hidden_states: Optional[torch.FloatTensor],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_1(hidden_states)
|
hidden_states = self.ln_1(hidden_states)
|
||||||
@ -277,6 +322,7 @@ class CodeGenBlock(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||||
outputs = attn_outputs[1:]
|
outputs = attn_outputs[1:]
|
||||||
@ -303,6 +349,9 @@ class CodeGenPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["CodeGenBlock"]
|
_no_split_modules = ["CodeGenBlock"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@ -374,6 +423,23 @@ CODEGEN_INPUTS_DOCSTRING = r"""
|
|||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
||||||
model's internal embedding lookup matrix.
|
model's internal embedding lookup matrix.
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@ -382,6 +448,10 @@ CODEGEN_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -397,7 +467,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||||
self.drop = nn.Dropout(config.embd_pdrop)
|
self.drop = nn.Dropout(config.embd_pdrop)
|
||||||
self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
|
self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
|
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
|
||||||
|
|
||||||
@ -421,7 +491,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@ -431,6 +501,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -439,85 +510,62 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError(
|
||||||
elif input_ids is not None:
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
)
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if token_type_ids is not None:
|
if inputs_embeds is None:
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
|
||||||
if past_key_values is None:
|
use_legacy_cache = False
|
||||||
past_length = 0
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
past_key_values = tuple([None] * len(self.h))
|
use_legacy_cache = True
|
||||||
else:
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_length = past_key_values[0][0].size(-2)
|
if not self.training:
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length = inputs_embeds.shape[1]
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
# Attention mask.
|
causal_mask = self._update_causal_mask(
|
||||||
if attention_mask is not None:
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
if batch_size <= 0:
|
)
|
||||||
raise ValueError("batch_size has to be defined and > 0")
|
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x num_attention_heads x N x N
|
# attention_probs has shape bsz x num_attention_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||||
token_type_embeds = self.wte(token_type_ids)
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
hidden_states = hidden_states + token_type_embeds
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
|
output_shape = (-1, seq_length, hidden_states.size(-1))
|
||||||
|
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
next_decoder_cache = None
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
|
||||||
"`use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -526,26 +574,28 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
None,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
@ -557,16 +607,94 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@ -591,26 +719,31 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
|
# Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM.prepare_inputs_for_generation
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
def prepare_inputs_for_generation(
|
||||||
# Omit tokens covered by past_key_values
|
self,
|
||||||
if past_key_values:
|
input_ids,
|
||||||
past_length = past_key_values[0][0].shape[2]
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
@ -618,19 +751,45 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": attention_mask,
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
"token_type_ids": token_type_ids,
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
@ -644,7 +803,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@ -655,6 +814,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
@ -676,6 +836,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -24,10 +24,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...activations import get_activation
|
from ...activations import get_activation
|
||||||
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...modeling_attn_mask_utils import (
|
from ...modeling_attn_mask_utils import (
|
||||||
AttentionMaskConverter,
|
AttentionMaskConverter,
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@ -62,6 +61,60 @@ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
|
|||||||
_CONFIG_FOR_DOC = "FalconConfig"
|
_CONFIG_FOR_DOC = "FalconConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
|
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
|
||||||
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
||||||
class FalconLinear(nn.Linear):
|
class FalconLinear(nn.Linear):
|
||||||
@ -244,7 +297,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
|||||||
|
|
||||||
|
|
||||||
class FalconAttention(nn.Module):
|
class FalconAttention(nn.Module):
|
||||||
def __init__(self, config: FalconConfig):
|
def __init__(self, config: FalconConfig, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -257,6 +310,13 @@ class FalconAttention(nn.Module):
|
|||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
if self.head_dim * self.num_heads != self.hidden_size:
|
if self.head_dim * self.num_heads != self.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -373,10 +433,11 @@ class FalconAttention(nn.Module):
|
|||||||
alibi: Optional[torch.Tensor],
|
alibi: Optional[torch.Tensor],
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||||
@ -391,25 +452,24 @@ class FalconAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_layer.shape[-2]
|
kv_seq_len = key_layer.shape[-2]
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
kv_seq_len += layer_past[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||||
if alibi is None:
|
if alibi is None:
|
||||||
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
||||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
# concatenate along seq_length dimension:
|
if alibi is None:
|
||||||
# - key: [batch_size, self.num_heads, kv_length, head_dim]
|
cache_kwargs.update({"sin": sin, "cos": cos})
|
||||||
# - value: [batch_size, self.num_heads, kv_length, head_dim]
|
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
||||||
key_layer = torch.cat((past_key, key_layer), dim=-2)
|
|
||||||
value_layer = torch.cat((past_value, value_layer), dim=-2)
|
|
||||||
|
|
||||||
kv_length = key_layer.shape[-2]
|
kv_length = key_layer.shape[-2]
|
||||||
if use_cache:
|
|
||||||
present = (key_layer, value_layer)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
|
if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
|
||||||
# For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
|
# For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
|
||||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||||
@ -417,6 +477,9 @@ class FalconAttention(nn.Module):
|
|||||||
key_layer = key_layer.contiguous()
|
key_layer = key_layer.contiguous()
|
||||||
value_layer = value_layer.contiguous()
|
value_layer = value_layer.contiguous()
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
|
||||||
|
|
||||||
if alibi is None:
|
if alibi is None:
|
||||||
if self._use_sdpa and not output_attentions:
|
if self._use_sdpa and not output_attentions:
|
||||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
|
||||||
@ -448,9 +511,9 @@ class FalconAttention(nn.Module):
|
|||||||
attn_output = self.dense(attn_output)
|
attn_output = self.dense(attn_output)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return attn_output, present, attention_scores
|
return attn_output, layer_past, attention_scores
|
||||||
else:
|
else:
|
||||||
return attn_output, present
|
return attn_output, layer_past
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self._use_sdpa and not output_attentions and head_mask is None:
|
if self._use_sdpa and not output_attentions and head_mask is None:
|
||||||
@ -502,9 +565,9 @@ class FalconAttention(nn.Module):
|
|||||||
attn_output = self.dense(attn_output)
|
attn_output = self.dense(attn_output)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return attn_output, present, attention_probs
|
return attn_output, layer_past, attention_probs
|
||||||
else:
|
else:
|
||||||
return attn_output, present
|
return attn_output, layer_past
|
||||||
|
|
||||||
|
|
||||||
class FalconFlashAttention2(FalconAttention):
|
class FalconFlashAttention2(FalconAttention):
|
||||||
@ -529,10 +592,11 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
alibi: Optional[torch.Tensor],
|
alibi: Optional[torch.Tensor],
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||||
@ -547,20 +611,22 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_layer.shape[-2]
|
kv_seq_len = key_layer.shape[-2]
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
kv_seq_len += layer_past[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||||
if alibi is None:
|
if alibi is None:
|
||||||
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
||||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
||||||
|
|
||||||
if layer_past is not None and use_cache:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
# concatenate along seq_length dimension:
|
if alibi is None:
|
||||||
# - key: [batch_size, self.num_heads, kv_length, head_dim]
|
cache_kwargs.update({"sin": sin, "cos": cos})
|
||||||
# - value: [batch_size, self.num_heads, kv_length, head_dim]
|
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
||||||
key_layer = torch.cat((past_key, key_layer), dim=-2)
|
|
||||||
value_layer = torch.cat((past_value, value_layer), dim=-2)
|
|
||||||
|
|
||||||
past_key_value = (key_layer, value_layer) if use_cache else None
|
|
||||||
|
|
||||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
# to be able to avoid many of these transpose/reshape/view.
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
@ -614,7 +680,7 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, past_key_value, attn_weights
|
return attn_output, layer_past, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class FalconMLP(nn.Module):
|
class FalconMLP(nn.Module):
|
||||||
@ -641,12 +707,12 @@ FALCON_ATTENTION_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
class FalconDecoderLayer(nn.Module):
|
class FalconDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: FalconConfig):
|
def __init__(self, config: FalconConfig, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
|
|
||||||
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
|
self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||||
self.mlp = FalconMLP(config)
|
self.mlp = FalconMLP(config)
|
||||||
self.hidden_dropout = config.hidden_dropout
|
self.hidden_dropout = config.hidden_dropout
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -672,10 +738,11 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
alibi: Optional[torch.Tensor],
|
alibi: Optional[torch.Tensor],
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -696,6 +763,7 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
attention_output = attn_outputs[0]
|
attention_output = attn_outputs[0]
|
||||||
@ -731,7 +799,7 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
outputs = (output,) + outputs[1:]
|
outputs = (output,) + outputs[1:]
|
||||||
|
|
||||||
return outputs # hidden_states, present, attentions
|
return outputs # hidden_states, past_kv, attentions
|
||||||
|
|
||||||
|
|
||||||
FALCON_START_DOCSTRING = r"""
|
FALCON_START_DOCSTRING = r"""
|
||||||
@ -762,14 +830,23 @@ FALCON_INPUTS_DOCSTRING = r"""
|
|||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
[What are input IDs?](../glossary#input-ids)
|
||||||
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Each element of `past_key_values` is a tuple (past_key, past_value):
|
Two formats are allowed:
|
||||||
- past_key: [batch_size * num_heads, head_dim, kv_length]
|
- a [`~cache_utils.Cache`] instance;
|
||||||
- past_value: [batch_size * num_heads, kv_length, head_dim]
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@ -806,6 +883,10 @@ FALCON_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -821,6 +902,9 @@ class FalconPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["FalconDecoderLayer"]
|
_no_split_modules = ["FalconDecoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@ -877,7 +961,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
|
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||||
|
|
||||||
# Transformer blocks
|
# Transformer blocks
|
||||||
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||||
|
|
||||||
@ -904,7 +988,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.LongTensor] = None,
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
@ -913,6 +997,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -921,38 +1006,35 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError(
|
||||||
elif input_ids is not None:
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
batch_size, seq_length = input_ids.shape
|
)
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if past_key_values is None:
|
if self.gradient_checkpointing and self.training:
|
||||||
past_key_values = tuple([None] * len(self.h))
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
|
|
||||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||||
past_key_values_length = 0
|
use_legacy_cache = False
|
||||||
if past_key_values[0] is not None:
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
past_key_values_length = past_key_values[0][0].shape[-2]
|
use_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
if not self.training:
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
|
||||||
|
alibi = None
|
||||||
|
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
if self.use_alibi:
|
if self.use_alibi:
|
||||||
mask = (
|
mask = (
|
||||||
torch.ones(
|
torch.ones(
|
||||||
@ -961,67 +1043,32 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
if attention_mask is None
|
if attention_mask is None
|
||||||
else attention_mask
|
else attention_mask
|
||||||
)
|
)
|
||||||
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
|
alibi = build_alibi_tensor(mask, self.num_heads, dtype=inputs_embeds.dtype)
|
||||||
else:
|
|
||||||
alibi = None
|
|
||||||
if position_ids is None:
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
position_ids = torch.arange(
|
|
||||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
if self._use_flash_attention_2:
|
if cache_position is None:
|
||||||
# 2d mask is passed through the layers
|
cache_position = torch.arange(
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||||
elif self._use_sdpa and not output_attentions:
|
|
||||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
|
||||||
# the manual implementation that requires a 4D causal mask in all cases.
|
|
||||||
if alibi is None:
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
)
|
|
||||||
elif head_mask is None:
|
|
||||||
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
|
||||||
|
|
||||||
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# We take care to integrate alibi bias in the attention_mask here.
|
|
||||||
min_dtype = torch.finfo(alibi.dtype).min
|
|
||||||
attention_mask = torch.masked_fill(
|
|
||||||
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
|
||||||
attention_mask < -1,
|
|
||||||
min_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
|
||||||
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
|
||||||
if seq_length > 1 and attention_mask.device.type == "cuda":
|
|
||||||
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
|
|
||||||
else:
|
|
||||||
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 4d mask is passed through the layers
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, head_mask, alibi
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape batch_size x num_heads x N x N
|
# attention_probs has shape batch_size x num_heads x N x N
|
||||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
next_decoder_cache = None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
for i, block in enumerate(self.h):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -1030,28 +1077,30 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
alibi,
|
alibi,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
layer_past,
|
past_key_values,
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
alibi=alibi,
|
alibi=alibi,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
@ -1062,16 +1111,110 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
head_mask: torch.Tensor,
|
||||||
|
alibi: torch.Tensor,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not using_static_cache
|
||||||
|
and not output_attentions
|
||||||
|
and head_mask is None
|
||||||
|
and alibi is None
|
||||||
|
):
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
batch_size, sequence_length, _ = input_tensor.shape
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
# We take care to integrate alibi bias in the causal_mask here
|
||||||
|
if head_mask is None and alibi is not None:
|
||||||
|
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
|
||||||
|
causal_mask = torch.masked_fill(
|
||||||
|
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
|
||||||
|
causal_mask < -1,
|
||||||
|
min_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
|
"The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
|
||||||
@ -1097,23 +1240,22 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[Union[Cache, torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
# Some generation methods already pass only the last input ID
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
if input_ids.shape[1] > past_length:
|
input_ids = input_ids[:, cache_position]
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
|
||||||
|
|
||||||
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
|
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
|
||||||
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
|
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
|
||||||
@ -1123,16 +1265,43 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": use_cache,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -1147,7 +1316,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
@ -1157,6 +1326,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
@ -1178,6 +1348,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...file_utils import ModelOutput
|
from ...file_utils import ModelOutput
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@ -124,13 +125,20 @@ class GitEmbeddings(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GitSelfAttention(nn.Module):
|
class GitSelfAttention(nn.Module):
|
||||||
def __init__(self, config, position_embedding_type=None):
|
def __init__(self, config, position_embedding_type=None, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||||
f"heads ({config.num_attention_heads})"
|
f"heads ({config.num_attention_heads})"
|
||||||
)
|
)
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
@ -161,46 +169,31 @@ class GitSelfAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
pixel_values_present: Optional[bool] = False,
|
pixel_values_present: Optional[bool] = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
cutoff = self.image_patch_tokens if pixel_values_present else 0
|
cutoff = self.image_patch_tokens if pixel_values_present else 0
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
|
||||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
key_layer_past, value_layer_past = past_key_value.update(
|
||||||
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
|
key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
|
||||||
value_layer = torch.cat(
|
|
||||||
[value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
|
|
||||||
)
|
)
|
||||||
else:
|
key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
|
||||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)
|
||||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
use_cache = past_key_value is not None
|
|
||||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
|
||||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
|
||||||
# key/value_states (first "if" case)
|
|
||||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
|
||||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
|
||||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
|
||||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
|
||||||
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
|
|
||||||
past_key_value = (
|
|
||||||
key_layer[:, :, cutoff:, :],
|
|
||||||
value_layer[:, :, cutoff:, :],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||||
if use_cache:
|
if past_key_value is not None:
|
||||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||||
-1, 1
|
-1, 1
|
||||||
)
|
)
|
||||||
@ -269,11 +262,10 @@ GIT_SELF_ATTENTION_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
class GitAttention(nn.Module):
|
class GitAttention(nn.Module):
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT
|
def __init__(self, config, position_embedding_type=None, layer_idx=None):
|
||||||
def __init__(self, config, position_embedding_type=None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
config, position_embedding_type=position_embedding_type
|
config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
|
||||||
)
|
)
|
||||||
self.output = GitSelfOutput(config)
|
self.output = GitSelfOutput(config)
|
||||||
self.pruned_heads = set()
|
self.pruned_heads = set()
|
||||||
@ -302,7 +294,7 @@ class GitAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
pixel_values_present: Optional[bool] = False,
|
pixel_values_present: Optional[bool] = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
@ -351,11 +343,11 @@ class GitOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GitLayer(nn.Module):
|
class GitLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
self.seq_len_dim = 1
|
self.seq_len_dim = 1
|
||||||
self.attention = GitAttention(config)
|
self.attention = GitAttention(config, layer_idx=layer_idx)
|
||||||
self.intermediate = GitIntermediate(config)
|
self.intermediate = GitIntermediate(config)
|
||||||
self.output = GitOutput(config)
|
self.output = GitOutput(config)
|
||||||
|
|
||||||
@ -364,18 +356,17 @@ class GitLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
pixel_values_present: Optional[bool] = False,
|
pixel_values_present: Optional[bool] = False,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=past_key_value,
|
||||||
pixel_values_present=pixel_values_present,
|
pixel_values_present=pixel_values_present,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
@ -401,11 +392,10 @@ class GitLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GitEncoder(nn.Module):
|
class GitEncoder(nn.Module):
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -413,7 +403,7 @@ class GitEncoder(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states: Optional[bool] = False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
@ -427,16 +417,23 @@ class GitEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
|
use_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
|
||||||
|
use_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
next_decoder_cache = () if use_cache else None
|
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
@ -444,7 +441,7 @@ class GitEncoder(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -452,26 +449,30 @@ class GitEncoder(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
pixel_values_present,
|
pixel_values_present,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[-1],)
|
next_decoder_cache = layer_outputs[-1]
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [
|
for v in [
|
||||||
hidden_states,
|
hidden_states,
|
||||||
next_decoder_cache,
|
next_cache,
|
||||||
all_hidden_states,
|
all_hidden_states,
|
||||||
all_self_attentions,
|
all_self_attentions,
|
||||||
]
|
]
|
||||||
@ -479,7 +480,7 @@ class GitEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=next_decoder_cache,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
@ -494,6 +495,8 @@ class GitPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = GitConfig
|
config_class = GitConfig
|
||||||
base_model_prefix = "git"
|
base_model_prefix = "git"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@ -569,6 +572,23 @@ GIT_INPUTS_DOCSTRING = r"""
|
|||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
model's internal embedding lookup matrix.
|
model's internal embedding lookup matrix.
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@ -1136,19 +1156,13 @@ class GitModel(GitPreTrainedModel):
|
|||||||
pixel_values: Optional[torch.Tensor] = None,
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
use_cache (`bool`, *optional*):
|
use_cache (`bool`, *optional*):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
@ -1195,7 +1209,13 @@ class GitModel(GitPreTrainedModel):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
# past_key_values_length
|
# past_key_values_length
|
||||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
past_key_values_length = 0
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = (
|
||||||
|
past_key_values[0][0].shape[2]
|
||||||
|
if not isinstance(past_key_values, Cache)
|
||||||
|
else past_key_values.get_seq_length()
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
@ -1327,7 +1347,7 @@ class GitForCausalLM(GitPreTrainedModel):
|
|||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[List[torch.Tensor]] = None,
|
past_key_values: Optional[Union[Cache, List[torch.Tensor]]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
@ -1338,12 +1358,6 @@ class GitForCausalLM(GitPreTrainedModel):
|
|||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||||
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
|
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
use_cache (`bool`, *optional*):
|
use_cache (`bool`, *optional*):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
@ -1522,7 +1536,16 @@ class GitForCausalLM(GitPreTrainedModel):
|
|||||||
):
|
):
|
||||||
# cut decoder_input_ids if past_key_values is used
|
# cut decoder_input_ids if past_key_values is used
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
past_length = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
# Some generation methods already pass only the last input ID
|
||||||
|
if input_ids.shape[1] > past_length:
|
||||||
|
remove_prefix_length = past_length
|
||||||
|
else:
|
||||||
|
# Default to old behavior: keep only final ID
|
||||||
|
remove_prefix_length = input_ids.shape[1] - 1
|
||||||
|
|
||||||
|
input_ids = input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
@ -23,7 +23,8 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
|
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@ -68,6 +69,60 @@ _CONFIG_FOR_DOC = "GPTNeoConfig"
|
|||||||
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
|
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
|
def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
|
||||||
"""Load tf checkpoints in a pytorch model"""
|
"""Load tf checkpoints in a pytorch model"""
|
||||||
try:
|
try:
|
||||||
@ -149,7 +204,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoSelfAttention(nn.Module):
|
class GPTNeoSelfAttention(nn.Module):
|
||||||
def __init__(self, config, attention_type):
|
def __init__(self, config, attention_type, layer_id=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@ -170,6 +225,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
|
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
|
||||||
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
|
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.num_heads = config.num_heads
|
self.num_heads = config.num_heads
|
||||||
@ -208,6 +264,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
|
# Apply sliding window masking for local attention layers
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
@ -216,9 +273,9 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
# Apply the attention mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
attn_weights = attn_weights.to(value.dtype)
|
attn_weights = attn_weights.to(value.dtype)
|
||||||
@ -240,6 +297,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
|
cache_position=None,
|
||||||
):
|
):
|
||||||
query = self.q_proj(hidden_states)
|
query = self.q_proj(hidden_states)
|
||||||
key = self.k_proj(hidden_states)
|
key = self.k_proj(hidden_states)
|
||||||
@ -250,15 +308,8 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
past_value = layer_past[1]
|
key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
present = (key, value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||||
|
|
||||||
@ -266,11 +317,11 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights,)
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
return outputs # a, present, (attentions)
|
return outputs # a, past_kv, (attentions)
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
||||||
@ -297,6 +348,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
|
cache_position=None,
|
||||||
):
|
):
|
||||||
bsz, _, _ = hidden_states.size()
|
bsz, _, _ = hidden_states.size()
|
||||||
|
|
||||||
@ -309,15 +361,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
|||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
past_value = layer_past[1]
|
key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
present = (key, value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
query_length = query.shape[2]
|
query_length = query.shape[2]
|
||||||
tgt_len = key.shape[2]
|
tgt_len = key.shape[2]
|
||||||
@ -330,6 +375,9 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
|||||||
|
|
||||||
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
@ -371,7 +419,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
|||||||
attn_output = self.out_proj(attn_weights_reshaped)
|
attn_output = self.out_proj(attn_weights_reshaped)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights_reshaped,)
|
outputs += (attn_weights_reshaped,)
|
||||||
|
|
||||||
@ -392,7 +440,9 @@ class GPTNeoAttention(nn.Module):
|
|||||||
self.attention_type = self.attention_layers[layer_id]
|
self.attention_type = self.attention_layers[layer_id]
|
||||||
|
|
||||||
if self.attention_type in ["global", "local"]:
|
if self.attention_type in ["global", "local"]:
|
||||||
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](config, self.attention_type)
|
self.attention = GPT_NEO_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
|
config, self.attention_type, layer_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
|
"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: "
|
||||||
@ -407,6 +457,7 @@ class GPTNeoAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
|
cache_position=None,
|
||||||
):
|
):
|
||||||
return self.attention(
|
return self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -415,6 +466,7 @@ class GPTNeoAttention(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -436,7 +488,7 @@ class GPTNeoMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoBlock(nn.Module):
|
class GPTNeoBlock(nn.Module):
|
||||||
def __init__(self, config, layer_id):
|
def __init__(self, config, layer_id=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
|
||||||
@ -453,6 +505,7 @@ class GPTNeoBlock(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
|
cache_position=None,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_1(hidden_states)
|
hidden_states = self.ln_1(hidden_states)
|
||||||
@ -463,6 +516,7 @@ class GPTNeoBlock(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||||
outputs = attn_outputs[1:]
|
outputs = attn_outputs[1:]
|
||||||
@ -480,7 +534,7 @@ class GPTNeoBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
outputs = (hidden_states,) + outputs[1:]
|
outputs = (hidden_states,) + outputs[1:]
|
||||||
|
|
||||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
return outputs # hidden_states, past_kv, attentions
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoPreTrainedModel(PreTrainedModel):
|
class GPTNeoPreTrainedModel(PreTrainedModel):
|
||||||
@ -496,6 +550,9 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["GPTNeoBlock"]
|
_no_split_modules = ["GPTNeoBlock"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = False # TODO: needs a HybridCache
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@ -547,10 +604,23 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
|
|||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
[What are input IDs?](../glossary#input-ids)
|
||||||
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_layers`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@ -595,6 +665,10 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -611,7 +685,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.drop = nn.Dropout(float(config.embed_dropout))
|
self.drop = nn.Dropout(float(config.embed_dropout))
|
||||||
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
|
||||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@ -633,7 +706,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor] = None,
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
token_type_ids: Optional[torch.Tensor] = None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
@ -643,6 +716,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -651,58 +725,10 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError(
|
||||||
elif input_ids is not None:
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
)
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
if past_key_values is None:
|
|
||||||
past_length = 0
|
|
||||||
past_key_values = tuple([None] * len(self.h))
|
|
||||||
else:
|
|
||||||
past_length = past_key_values[0][0].size(-2)
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x num_heads x N x N
|
|
||||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
position_embeds = self.wpe(position_ids)
|
|
||||||
hidden_states = inputs_embeds + position_embeds
|
|
||||||
|
|
||||||
# Attention mask.
|
|
||||||
if self._use_flash_attention_2:
|
|
||||||
# 2d mask is passed through the layers
|
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
|
||||||
else:
|
|
||||||
# 4d mask is passed through the layers
|
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_length)
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_embeds = self.wte(token_type_ids)
|
|
||||||
hidden_states = hidden_states + token_type_embeds
|
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
|
||||||
|
|
||||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@ -711,10 +737,51 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
presents = () if use_cache else None
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
|
||||||
|
use_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
|
||||||
|
use_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
if not self.training:
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length = inputs_embeds.shape[1]
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x num_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
output_shape = (-1, seq_length, hidden_states.size(-1))
|
||||||
|
|
||||||
|
next_decoder_cache = None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -723,24 +790,26 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
None,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
@ -752,16 +821,94 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@ -787,26 +934,30 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
def prepare_inputs_for_generation(
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
self,
|
||||||
# Omit tokens covered by past_key_values
|
input_ids,
|
||||||
if past_key_values:
|
attention_mask=None,
|
||||||
past_length = past_key_values[0][0].shape[2]
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
@ -814,22 +965,47 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": attention_mask,
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
"token_type_ids": token_type_ids,
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||||
@ -841,7 +1017,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor] = None,
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
token_type_ids: Optional[torch.Tensor] = None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
@ -852,6 +1028,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
@ -873,6 +1050,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
@ -957,7 +1135,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor] = None,
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[torch.FloatTensor]]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
token_type_ids: Optional[torch.Tensor] = None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
@ -1081,7 +1259,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
@ -23,13 +23,14 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@ -52,6 +53,60 @@ _REAL_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neox-20b"
|
|||||||
_CONFIG_FOR_DOC = "GPTNeoXConfig"
|
_CONFIG_FOR_DOC = "GPTNeoXConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
class GPTNeoXPreTrainedModel(PreTrainedModel):
|
class GPTNeoXPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@ -64,6 +119,9 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["GPTNeoXLayer"]
|
_no_split_modules = ["GPTNeoXLayer"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
@ -82,7 +140,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoXAttention(nn.Module):
|
class GPTNeoXAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
@ -98,11 +156,18 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
||||||
self._init_rope()
|
self._init_rope()
|
||||||
|
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
self.norm_factor = self.head_size**-0.5
|
self.norm_factor = self.head_size**-0.5
|
||||||
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
|
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
def _init_bias(self, max_positions, device=None):
|
def _init_bias(self, max_positions, device=None):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
@ -146,9 +211,11 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
attention_mask: torch.FloatTensor,
|
attention_mask: torch.FloatTensor,
|
||||||
position_ids: torch.LongTensor,
|
position_ids: torch.LongTensor,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
# Apply attention-specific projections and rope
|
# Apply attention-specific projections and rope
|
||||||
query, key, value, present = self._attn_projections_and_rope(
|
query, key, value, present = self._attn_projections_and_rope(
|
||||||
@ -199,9 +266,8 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
position_ids: torch.LongTensor,
|
position_ids: torch.LongTensor,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
has_layer_past = layer_past is not None
|
|
||||||
|
|
||||||
# Compute QKV
|
# Compute QKV
|
||||||
# Attention heads [batch, seq_len, hidden_size]
|
# Attention heads [batch, seq_len, hidden_size]
|
||||||
# --> [batch, seq_len, (np * 3 * head_size)]
|
# --> [batch, seq_len, (np * 3 * head_size)]
|
||||||
@ -225,22 +291,31 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute token offset for rotary embeddings (when decoding)
|
# Compute token offset for rotary embeddings (when decoding)
|
||||||
seq_len = key.shape[-2]
|
seq_len = key.shape[-2]
|
||||||
if has_layer_past:
|
if layer_past is not None:
|
||||||
seq_len += layer_past[0].shape[-2]
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||||
query = torch.cat((query, query_pass), dim=-1)
|
query = torch.cat((query, query_pass), dim=-1)
|
||||||
key = torch.cat((key, key_pass), dim=-1)
|
key = torch.cat((key, key_pass), dim=-1)
|
||||||
|
|
||||||
# Cache QKV values
|
# Cache QKV values
|
||||||
if has_layer_past:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {
|
||||||
past_value = layer_past[1]
|
"sin": sin,
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
"cos": cos,
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
"partial_rotation_size": self.rotary_emb.dim,
|
||||||
present = (key, value) if use_cache else None
|
"cache_position": cache_position,
|
||||||
|
}
|
||||||
|
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
return query, key, value, present
|
return query, key, value, layer_past
|
||||||
|
|
||||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||||
@ -277,9 +352,9 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
|
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
|
||||||
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
|
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
# Apply the attention mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
attn_scores = attn_scores + attention_mask
|
attn_scores = attn_scores + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
|
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
|
||||||
attn_weights = attn_weights.to(value.dtype)
|
attn_weights = attn_weights.to(value.dtype)
|
||||||
@ -316,13 +391,18 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
|
|||||||
attention_mask: torch.FloatTensor,
|
attention_mask: torch.FloatTensor,
|
||||||
position_ids: torch.LongTensor,
|
position_ids: torch.LongTensor,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
# Apply attention-specific projections and rope
|
# Apply attention-specific projections and rope
|
||||||
query, key, value, present = self._attn_projections_and_rope(
|
query, key, value, present = self._attn_projections_and_rope(
|
||||||
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
|
hidden_states=hidden_states,
|
||||||
|
position_ids=position_ids,
|
||||||
|
layer_past=layer_past,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
query_length = query.shape[-2]
|
query_length = query.shape[-2]
|
||||||
@ -384,7 +464,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
|
|||||||
)
|
)
|
||||||
attn_output = self.dense(attn_output)
|
attn_output = self.dense(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights,)
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
@ -398,8 +478,8 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
|||||||
to adapt to the SDPA API.
|
to adapt to the SDPA API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__(config)
|
super().__init__(config, layer_idx=layer_idx)
|
||||||
|
|
||||||
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
||||||
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
|
# attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
|
||||||
@ -415,6 +495,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
|||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
if output_attentions or head_mask is not None:
|
if output_attentions or head_mask is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@ -431,15 +512,24 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
|||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# Apply attention-specific projections and rope
|
# Apply attention-specific projections and rope
|
||||||
query, key, value, present = self._attn_projections_and_rope(
|
query, key, value, present = self._attn_projections_and_rope(
|
||||||
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
|
hidden_states=hidden_states,
|
||||||
|
position_ids=position_ids,
|
||||||
|
layer_past=layer_past,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
causal_mask = attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
|
||||||
|
|
||||||
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
|
# GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
|
||||||
target_dtype = value.dtype
|
target_dtype = value.dtype
|
||||||
if query.dtype != target_dtype:
|
if query.dtype != target_dtype:
|
||||||
@ -455,13 +545,13 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
|||||||
|
|
||||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
is_causal = True if attention_mask is None and q_len > 1 else False
|
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
attn_mask=attention_mask,
|
attn_mask=causal_mask,
|
||||||
dropout_p=self.attention_dropout.p if self.training else 0.0,
|
dropout_p=self.attention_dropout.p if self.training else 0.0,
|
||||||
is_causal=is_causal,
|
is_causal=is_causal,
|
||||||
)
|
)
|
||||||
@ -624,14 +714,14 @@ GPT_NEOX_ATTENTION_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoXLayer(nn.Module):
|
class GPTNeoXLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_parallel_residual = config.use_parallel_residual
|
self.use_parallel_residual = config.use_parallel_residual
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
|
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
|
||||||
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
|
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
|
||||||
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config)
|
self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||||
self.mlp = GPTNeoXMLP(config)
|
self.mlp = GPTNeoXMLP(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -641,8 +731,9 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
attention_layer_outputs = self.attention(
|
attention_layer_outputs = self.attention(
|
||||||
self.input_layernorm(hidden_states),
|
self.input_layernorm(hidden_states),
|
||||||
@ -652,6 +743,7 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
||||||
attn_output = self.post_attention_dropout(attn_output)
|
attn_output = self.post_attention_dropout(attn_output)
|
||||||
@ -722,6 +814,23 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
|
|||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
||||||
model's internal embedding lookup matrix.
|
model's internal embedding lookup matrix.
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@ -730,6 +839,10 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -744,7 +857,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
|
|
||||||
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
self.emb_dropout = nn.Dropout(config.hidden_dropout)
|
self.emb_dropout = nn.Dropout(config.hidden_dropout)
|
||||||
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([GPTNeoXLayer(config, i) for i in range(config.num_hidden_layers)])
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
self._attn_implementation = config._attn_implementation
|
self._attn_implementation = config._attn_implementation
|
||||||
@ -774,18 +887,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
use_cache (`bool`, *optional*):
|
use_cache (`bool`, *optional*):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
@ -797,50 +906,42 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError(
|
||||||
elif input_ids is not None:
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
)
|
||||||
input_shape = input_ids.size()
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
batch_size, seq_length = input_shape
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
if past_key_values is None:
|
logger.warning_once(
|
||||||
past_length = 0
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
)
|
||||||
else:
|
use_cache = False
|
||||||
past_length = past_key_values[0][0].size(-2)
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_in(input_ids)
|
inputs_embeds = self.embed_in(input_ids)
|
||||||
|
|
||||||
# Attention mask.
|
use_legacy_cache = False
|
||||||
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
if self._attn_implementation == "flash_attention_2":
|
use_legacy_cache = True
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
|
if not self.training:
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
logger.warning_once(
|
||||||
attention_mask=attention_mask,
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
input_shape=(batch_size, seq_length),
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
inputs_embeds=inputs_embeds,
|
)
|
||||||
past_key_values_length=past_length,
|
|
||||||
)
|
seq_length = inputs_embeds.shape[1]
|
||||||
else:
|
if cache_position is None:
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
attention_mask=attention_mask,
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
|
||||||
input_shape=(batch_size, seq_length),
|
|
||||||
inputs_embeds=inputs_embeds,
|
if position_ids is None:
|
||||||
past_key_values_length=past_length,
|
position_ids = cache_position.unsqueeze(0)
|
||||||
)
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
@ -848,20 +949,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
hidden_states = self.emb_dropout(inputs_embeds)
|
hidden_states = self.emb_dropout(inputs_embeds)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
next_decoder_cache = None
|
||||||
if use_cache:
|
|
||||||
logger.warning(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
|
for i, layer in enumerate(
|
||||||
|
self.layers,
|
||||||
|
):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@ -869,26 +964,28 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
outputs = self._gradient_checkpointing_func(
|
outputs = self._gradient_checkpointing_func(
|
||||||
layer.__call__,
|
layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
None,
|
None,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = layer(
|
outputs = layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
|
||||||
@ -897,16 +994,92 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_attentions,
|
attentions=all_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
|
"""GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
|
||||||
@ -938,26 +1111,15 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
|
|
||||||
only required when the model is used as a decoder in a Sequence to Sequence model.
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
|
|
||||||
`past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||||
@ -997,6 +1159,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
@ -1024,24 +1187,27 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# can't be copied from llama, gpt-neox has emebd_out and not lm_head
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
input_shape = input_ids.shape
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
# cut decoder_input_ids if past is used
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
@ -1049,24 +1215,46 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
if attention_mask is None:
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
dtype = self.embed_out.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def _reorder_cache(self, past_key_values, beam_idx):
|
def _reorder_cache(self, past_key_values, beam_idx):
|
||||||
@ -1117,7 +1305,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@ -1229,7 +1417,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
@ -24,6 +24,8 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@ -55,6 +57,60 @@ _REAL_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
|
|||||||
_CONFIG_FOR_DOC = "GPTJConfig"
|
_CONFIG_FOR_DOC = "GPTJConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
||||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
||||||
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
|
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
|
||||||
@ -80,23 +136,22 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
|
|||||||
|
|
||||||
|
|
||||||
class GPTJAttention(nn.Module):
|
class GPTJAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
|
||||||
"bias",
|
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
|
||||||
1, 1, max_positions, max_positions
|
|
||||||
),
|
|
||||||
persistent=False,
|
|
||||||
)
|
|
||||||
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
|
||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
@ -152,27 +207,16 @@ class GPTJAttention(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
):
|
):
|
||||||
# compute causal mask from causal mask buffer
|
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
|
||||||
|
|
||||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||||
query = query.to(torch.float32)
|
query = query.to(torch.float32)
|
||||||
key = key.to(torch.float32)
|
key = key.to(torch.float32)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
|
||||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
|
||||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
|
||||||
|
|
||||||
attn_weights = attn_weights / self.scale_attn
|
attn_weights = attn_weights / self.scale_attn
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
# Apply the attention mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
attn_weights = attn_weights.to(value.dtype)
|
attn_weights = attn_weights.to(value.dtype)
|
||||||
@ -196,12 +240,13 @@ class GPTJAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.FloatTensor,
|
hidden_states: torch.FloatTensor,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[
|
) -> Union[
|
||||||
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
||||||
@ -245,17 +290,13 @@ class GPTJAttention(nn.Module):
|
|||||||
query = query.permute(0, 2, 1, 3)
|
query = query.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {
|
||||||
past_value = layer_past[1]
|
"sin": sin,
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
"cos": cos,
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
"partial_rotation_size": self.rotary_dim,
|
||||||
|
"cache_position": cache_position,
|
||||||
if use_cache is True:
|
}
|
||||||
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
|
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
|
||||||
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
|
|
||||||
present = (key.to(hidden_states.dtype), value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
# compute self-attention: V x Softmax(QK^T)
|
# compute self-attention: V x Softmax(QK^T)
|
||||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||||
@ -264,7 +305,7 @@ class GPTJAttention(nn.Module):
|
|||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights,)
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
@ -290,12 +331,13 @@ class GPTJFlashAttention2(GPTJAttention):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.FloatTensor,
|
hidden_states: torch.FloatTensor,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[
|
) -> Union[
|
||||||
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
||||||
@ -343,17 +385,13 @@ class GPTJFlashAttention2(GPTJAttention):
|
|||||||
# value: batch_size x num_attention_heads x seq_length x head_dim
|
# value: batch_size x num_attention_heads x seq_length x head_dim
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key = layer_past[0]
|
cache_kwargs = {
|
||||||
past_value = layer_past[1]
|
"sin": sin,
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
"cos": cos,
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
"partial_rotation_size": self.rotary_dim,
|
||||||
|
"cache_position": cache_position,
|
||||||
if use_cache is True:
|
}
|
||||||
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
|
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
|
||||||
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
|
|
||||||
present = (key.to(hidden_states.dtype), value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
# The Flash attention requires the input to have the shape
|
# The Flash attention requires the input to have the shape
|
||||||
# batch_size x seq_length x head_dim x hidden_dim
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
@ -412,7 +450,7 @@ class GPTJFlashAttention2(GPTJAttention):
|
|||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, layer_past)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (attn_weights,)
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
@ -445,22 +483,23 @@ class GPTJMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GPTJBlock(nn.Module):
|
class GPTJBlock(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, layer_idx=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
||||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
|
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||||
self.mlp = GPTJMLP(inner_dim, config)
|
self.mlp = GPTJMLP(inner_dim, config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.FloatTensor],
|
hidden_states: Optional[torch.FloatTensor],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Cache] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_1(hidden_states)
|
hidden_states = self.ln_1(hidden_states)
|
||||||
@ -472,6 +511,7 @@ class GPTJBlock(nn.Module):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||||
outputs = attn_outputs[1:]
|
outputs = attn_outputs[1:]
|
||||||
@ -500,6 +540,9 @@ class GPTJPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["GPTJBlock"]
|
_no_split_modules = ["GPTJBlock"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
_supports_param_buffer_assignment = False
|
_supports_param_buffer_assignment = False
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
@ -572,6 +615,23 @@ GPTJ_INPUTS_DOCSTRING = r"""
|
|||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
||||||
model's internal embedding lookup matrix.
|
model's internal embedding lookup matrix.
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@ -580,6 +640,10 @@ GPTJ_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PARALLELIZE_DOCSTRING = r"""
|
PARALLELIZE_DOCSTRING = r"""
|
||||||
@ -643,7 +707,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||||
self.drop = nn.Dropout(config.embd_pdrop)
|
self.drop = nn.Dropout(config.embd_pdrop)
|
||||||
self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])
|
self.h = nn.ModuleList([GPTJBlock(config, layer_idx=i) for i in range(config.n_layer)])
|
||||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
# Model parallel
|
# Model parallel
|
||||||
@ -714,7 +778,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@ -724,6 +788,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@ -732,73 +797,10 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError(
|
||||||
elif input_ids is not None:
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
)
|
||||||
input_shape = input_ids.size()
|
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
if past_key_values is None:
|
|
||||||
past_length = 0
|
|
||||||
past_key_values = tuple([None] * len(self.h))
|
|
||||||
else:
|
|
||||||
past_length = past_key_values[0][0].size(-2)
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
if not self._use_flash_attention_2:
|
|
||||||
# Attention mask.
|
|
||||||
if attention_mask is not None:
|
|
||||||
if batch_size <= 0:
|
|
||||||
raise ValueError("batch_size has to be defined and > 0")
|
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x num_attention_heads x N x N
|
|
||||||
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_embeds = self.wte(token_type_ids)
|
|
||||||
hidden_states = hidden_states + token_type_embeds
|
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
|
||||||
|
|
||||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@ -807,19 +809,64 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
presents = () if use_cache else None
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
|
||||||
|
use_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
use_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
if not self.training:
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length = inputs_embeds.shape[1]
|
||||||
|
if cache_position is None:
|
||||||
|
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x num_attention_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
output_shape = (-1, seq_length, hidden_states.size(-1))
|
||||||
|
|
||||||
|
next_decoder_cache = None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
# Model parallel
|
# Model parallel
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
torch.cuda.set_device(hidden_states.device)
|
torch.cuda.set_device(hidden_states.device)
|
||||||
|
|
||||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||||
if layer_past is not None:
|
if past_key_values is not None:
|
||||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device)
|
||||||
|
past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device)
|
||||||
|
|
||||||
# Ensure that attention_mask is always on the same device as hidden_states
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
if attention_mask is not None:
|
if causal_mask is not None:
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
causal_mask = causal_mask.to(hidden_states.device)
|
||||||
if isinstance(head_mask, torch.Tensor):
|
if isinstance(head_mask, torch.Tensor):
|
||||||
head_mask = head_mask.to(hidden_states.device)
|
head_mask = head_mask.to(hidden_states.device)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@ -830,26 +877,28 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
None,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
presents = presents + (outputs[1],)
|
next_decoder_cache = outputs[1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
@ -867,16 +916,94 @@ class GPTJModel(GPTJPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=next_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@ -936,26 +1063,31 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
# Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM.prepare_inputs_for_generation
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
def prepare_inputs_for_generation(
|
||||||
# Omit tokens covered by past_key_values
|
self,
|
||||||
if past_key_values:
|
input_ids,
|
||||||
past_length = past_key_values[0][0].shape[2]
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
use_cache=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
|
||||||
if input_ids.shape[1] > past_length:
|
|
||||||
remove_prefix_length = past_length
|
|
||||||
else:
|
|
||||||
# Default to old behavior: keep only final ID
|
|
||||||
remove_prefix_length = input_ids.shape[1] - 1
|
|
||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
@ -963,22 +1095,47 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
else:
|
else:
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_length(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": attention_mask,
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
"token_type_ids": token_type_ids,
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@ -991,7 +1148,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@ -1002,6 +1159,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
@ -1023,6 +1181,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -30,7 +30,8 @@ from torch.nn import CrossEntropyLoss
|
|||||||
|
|
||||||
from ... import PreTrainedModel
|
from ... import PreTrainedModel
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import ModelOutput
|
from ...modeling_outputs import ModelOutput
|
||||||
from ...modeling_utils import PretrainedConfig
|
from ...modeling_utils import PretrainedConfig
|
||||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
@ -50,6 +51,60 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "IdeficsConfig"
|
_CONFIG_FOR_DOC = "IdeficsConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IdeficsBaseModelOutputWithPast(ModelOutput):
|
class IdeficsBaseModelOutputWithPast(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@ -184,11 +239,13 @@ def expand_inputs_for_generation(
|
|||||||
|
|
||||||
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
cache_position = kwargs.get("cache_position", None)
|
||||||
if past_key_values:
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
if past_key_values is not None:
|
||||||
|
if input_ids.shape[1] != cache_position.shape[0]:
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
@ -200,6 +257,9 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
pixel_values = kwargs.get("pixel_values", None)
|
pixel_values = kwargs.get("pixel_values", None)
|
||||||
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
|
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
|
||||||
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
|
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
|
||||||
@ -210,6 +270,7 @@ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
|
|||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"cache_position": cache_position,
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"token_type_ids": token_type_ids,
|
"token_type_ids": token_type_ids,
|
||||||
@ -541,6 +602,7 @@ class IdeficsAttention(nn.Module):
|
|||||||
is_cross_attention: bool = False,
|
is_cross_attention: bool = False,
|
||||||
config: PretrainedConfig = None,
|
config: PretrainedConfig = None,
|
||||||
qk_layer_norms: bool = False,
|
qk_layer_norms: bool = False,
|
||||||
|
layer_idx: int = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -549,6 +611,14 @@ class IdeficsAttention(nn.Module):
|
|||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
if (self.head_dim * num_heads) != self.hidden_size:
|
if (self.head_dim * num_heads) != self.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
@ -615,6 +685,7 @@ class IdeficsAttention(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
is_cross_attention = self.is_cross_attention or key_value_states is not None
|
is_cross_attention = self.is_cross_attention or key_value_states is not None
|
||||||
@ -634,18 +705,17 @@ class IdeficsAttention(nn.Module):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += cache_position[0]
|
||||||
|
|
||||||
if not is_cross_attention:
|
if not is_cross_attention:
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
|
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
# [bsz, nh, t, hd]
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
if self.qk_layer_norms:
|
if self.qk_layer_norms:
|
||||||
query_states = self.q_layer_norm(query_states)
|
query_states = self.q_layer_norm(query_states)
|
||||||
@ -700,7 +770,7 @@ class IdeficsAttention(nn.Module):
|
|||||||
|
|
||||||
# this was adapted from LlamaDecoderLayer
|
# this was adapted from LlamaDecoderLayer
|
||||||
class IdeficsDecoderLayer(nn.Module):
|
class IdeficsDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: IdeficsConfig):
|
def __init__(self, config: IdeficsConfig, layer_idx: int = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.self_attn = IdeficsAttention(
|
self.self_attn = IdeficsAttention(
|
||||||
@ -708,6 +778,7 @@ class IdeficsDecoderLayer(nn.Module):
|
|||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
dropout=config.dropout,
|
dropout=config.dropout,
|
||||||
config=config,
|
config=config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
self.mlp = IdeficsMLP(
|
self.mlp = IdeficsMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -726,6 +797,7 @@ class IdeficsDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -753,6 +825,7 @@ class IdeficsDecoderLayer(nn.Module):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@ -944,6 +1017,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
|
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
# important: this ported version of Idefics isn't meant for training from scratch - only
|
# important: this ported version of Idefics isn't meant for training from scratch - only
|
||||||
@ -1031,6 +1105,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -1076,7 +1154,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
perceiver_config.resampler_n_latents,
|
perceiver_config.resampler_n_latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([IdeficsDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[IdeficsDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
self.cross_layer_interval = config.cross_layer_interval
|
self.cross_layer_interval = config.cross_layer_interval
|
||||||
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
|
||||||
@ -1132,6 +1212,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: Optional[bool] = False,
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
|
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
@ -1143,22 +1224,38 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# retrieve input_ids and inputs_embeds
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
raise ValueError(
|
||||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||||
elif input_ids is not None:
|
)
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
past_key_values_length = 0
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if past_key_values is not None:
|
if inputs_embeds is None:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
if not self.training:
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||||
|
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||||
|
)
|
||||||
|
return_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
seq_length_with_past = seq_length + past_key_values_length
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
# create position_ids on the fly for batch generation
|
# create position_ids on the fly for batch generation
|
||||||
@ -1229,37 +1326,27 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
device
|
device
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(
|
attention_mask = torch.ones(
|
||||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
attention_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
def vblock(
|
def vblock(
|
||||||
main_block,
|
main_block,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -1274,6 +1361,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
layer_idx,
|
layer_idx,
|
||||||
cross_layer_interval,
|
cross_layer_interval,
|
||||||
gated_cross_attn_layers,
|
gated_cross_attn_layers,
|
||||||
|
cache_position,
|
||||||
):
|
):
|
||||||
# TODO(ls): Add cross attention values to respective lists
|
# TODO(ls): Add cross attention values to respective lists
|
||||||
if layer_idx % cross_layer_interval == 0:
|
if layer_idx % cross_layer_interval == 0:
|
||||||
@ -1297,12 +1385,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
return layer_outputs
|
return layer_outputs
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
past_key_value = None
|
past_key_values = None
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
@ -1315,7 +1404,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_value,
|
past_key_values,
|
||||||
image_hidden_states,
|
image_hidden_states,
|
||||||
image_attention_mask,
|
image_attention_mask,
|
||||||
cross_attention_gate,
|
cross_attention_gate,
|
||||||
@ -1324,6 +1413,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
idx,
|
idx,
|
||||||
self.cross_layer_interval,
|
self.cross_layer_interval,
|
||||||
self.gated_cross_attn_layers,
|
self.gated_cross_attn_layers,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = vblock(
|
layer_outputs = vblock(
|
||||||
@ -1331,7 +1421,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values,
|
||||||
image_hidden_states=image_hidden_states,
|
image_hidden_states=image_hidden_states,
|
||||||
image_attention_mask=image_attention_mask,
|
image_attention_mask=image_attention_mask,
|
||||||
cross_attention_gate=cross_attention_gate,
|
cross_attention_gate=cross_attention_gate,
|
||||||
@ -1340,12 +1430,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
layer_idx=idx,
|
layer_idx=idx,
|
||||||
cross_layer_interval=self.cross_layer_interval,
|
cross_layer_interval=self.cross_layer_interval,
|
||||||
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
gated_cross_attn_layers=self.gated_cross_attn_layers,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
@ -1357,6 +1448,8 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
|
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
@ -1372,6 +1465,78 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
|||||||
image_hidden_states=image_hidden_states,
|
image_hidden_states=image_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
min_dtype=min_dtype,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||||
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
||||||
@ -1450,6 +1615,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: Optional[bool] = False,
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
|
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -1508,6 +1674,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
@ -1567,13 +1734,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
|||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
model_kwargs: Dict[str, Any],
|
model_kwargs: Dict[str, Any],
|
||||||
is_encoder_decoder: bool = False,
|
is_encoder_decoder: bool = False,
|
||||||
standardize_cache_format: bool = False,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
model_kwargs = super()._update_model_kwargs_for_generation(
|
model_kwargs = super()._update_model_kwargs_for_generation(
|
||||||
outputs,
|
outputs,
|
||||||
model_kwargs,
|
model_kwargs,
|
||||||
is_encoder_decoder,
|
is_encoder_decoder,
|
||||||
standardize_cache_format,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "image_attention_mask" in model_kwargs:
|
if "image_attention_mask" in model_kwargs:
|
||||||
|
@ -59,7 +59,7 @@ if is_torch_available():
|
|||||||
ImageGPTForCausalImageModeling,
|
ImageGPTForCausalImageModeling,
|
||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
)
|
)
|
||||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
|
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
BeamSampleDecoderOnlyOutput,
|
||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
@ -1769,6 +1769,53 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generate_with_static_cache(self):
|
||||||
|
"""
|
||||||
|
Tests if StaticCache works if we set attn_implementation=static when generation.
|
||||||
|
This doesn't test if generation quality is good, but tests that models with
|
||||||
|
self._supports_static_cache don't throw an error when generating and return
|
||||||
|
a StaticCache object at the end.
|
||||||
|
"""
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_static_cache:
|
||||||
|
self.skipTest(reason="This model does not support the static cache format")
|
||||||
|
|
||||||
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||||
|
|
||||||
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
max_new_tokens = 20
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
generation_kwargs = {
|
||||||
|
"max_length": None,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"cache_implementation": "static",
|
||||||
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
}
|
||||||
|
|
||||||
|
max_cache_len = seq_length + max_new_tokens
|
||||||
|
head_dim = (
|
||||||
|
model.config.head_dim
|
||||||
|
if hasattr(model.config, "head_dim")
|
||||||
|
else model.config.hidden_size // model.config.num_attention_heads
|
||||||
|
)
|
||||||
|
num_key_value_heads = (
|
||||||
|
model.config.num_attention_heads
|
||||||
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
|
else model.config.num_key_value_heads
|
||||||
|
)
|
||||||
|
num_hidden_layers = config.num_hidden_layers
|
||||||
|
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
|
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||||
|
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||||
|
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
|
||||||
|
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
|
||||||
|
|
||||||
@require_quanto
|
@require_quanto
|
||||||
def test_generate_with_quant_cache(self):
|
def test_generate_with_quant_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
|
@ -4587,6 +4587,44 @@ class ModelTesterMixin:
|
|||||||
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
|
||||||
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||||
|
|
||||||
|
def test_static_cache_matches_dynamic(self):
|
||||||
|
"""
|
||||||
|
Tests that generating with static cache give almost same results as with dynamic cache.
|
||||||
|
This test does not compile the model and check only logits similarity for numerical precision
|
||||||
|
errors.
|
||||||
|
"""
|
||||||
|
if len(self.all_generative_model_classes) == 0:
|
||||||
|
self.skipTest(
|
||||||
|
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
|
||||||
|
)
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_static_cache:
|
||||||
|
self.skipTest(f"{model_class.__name__} does not support static cache")
|
||||||
|
|
||||||
|
if not model_class._supports_cache_class:
|
||||||
|
self.skipTest(f"{model_class.__name__} does not support cache class")
|
||||||
|
|
||||||
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if getattr(config, "sliding_window", 0) > 0:
|
||||||
|
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
|
||||||
|
|
||||||
|
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dynamic_out = model.generate(
|
||||||
|
**inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
|
||||||
|
)
|
||||||
|
static_out = model.generate(
|
||||||
|
**inputs,
|
||||||
|
do_sample=False,
|
||||||
|
max_new_tokens=10,
|
||||||
|
cache_implementation="static",
|
||||||
|
output_logits=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
|
||||||
|
|
||||||
# For now, Let's focus only on GPU for `torch.compile`
|
# For now, Let's focus only on GPU for `torch.compile`
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
Loading…
Reference in New Issue
Block a user