Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00

Zamba new attention standard (#35375)

* updated zamba to new attention standard
* make fixup fixes

This commit is contained in:
parent 12ba96aa3c
commit bd442c6d3a
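For context on "new attention standard": this change moves Zamba from per-backend attention subclasses (eager / flash-attention-2 / sdpa) to the repo-wide pattern of one attention class that dispatches through a registry of backend functions. The sketch below illustrates that pattern in miniature; it is not the actual transformers internals — `ATTENTION_FUNCTIONS` and `pick_attention` are stand-ins for `ALL_ATTENTION_FUNCTIONS` and the `config._attn_implementation` lookup in the diff, and it assumes query/key/value already have shape (batch, num_heads, seq_len, head_dim) with matching query and key/value head counts.

import torch
from torch import nn
from typing import Callable, Dict


def eager_attention(module: nn.Module, query, key, value, attention_mask, scaling: float, dropout: float = 0.0):
    # Reference path: scaled QK^T, optional additive mask, softmax in fp32, dropout, weighted sum of values.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    return torch.matmul(attn_weights, value).transpose(1, 2).contiguous(), attn_weights


def sdpa_attention(module: nn.Module, query, key, value, attention_mask, scaling: float, dropout: float = 0.0):
    # Fused backend; cannot return attention weights (hence the output_attentions fallback in the diff).
    # Note: the `scale` keyword requires torch >= 2.1.
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Registry keyed by the configured implementation; eager is the fallback, mirroring the diff's
# `attention_interface: Callable = eager_attention_forward` default.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention}


def pick_attention(attn_implementation: str) -> Callable:
    return ATTENTION_FUNCTIONS.get(attn_implementation, eager_attention)

The payoff of this design, visible throughout the diff below, is that model code keeps a single forward pass and backend differences (flash-attention, sdpa, eager) live behind one callable signature.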
@@ -20,7 +20,7 @@
 """PyTorch Zamba model."""

 import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -33,18 +33,18 @@ from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class HybridMambaAttentionDynamicCache(DynamicCache):
+class ZambaHybridDynamicCache(DynamicCache):
     """
     A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
     (which has a constant shape regardless of seq_len).
@@ -131,9 +131,9 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self.dtype = dtype
         self.layers_block_type = config.layers_block_type
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.mamba_expand * config.hidden_size
-        ssm_state_size = config.mamba_d_state
-        conv_kernel_size = config.mamba_d_conv
+        self.intermediate_size = config.mamba_expand * config.hidden_size
+        self.ssm_state_size = config.mamba_d_state
+        self.conv_kernel_size = config.mamba_d_conv
         self.n_mamba_heads = config.n_mamba_heads
         self.conv_states = []
         self.ssm_states = []
@@ -143,9 +143,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self._buffers = {}
         for i in range(config.num_hidden_layers):
             self.conv_states += [
-                torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+                torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
             ]
-            cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
+            cache_shape = (
+                batch_size,
+                self.n_mamba_heads,
+                self.intermediate_size // self.n_mamba_heads,
+                self.ssm_state_size,
+            )
             self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
             if self.layers_block_type[i] == "hybrid":
                 self.transformer_layers.append(i)
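The constructor above preallocates one fixed-size convolution state and one ssm state per layer, independent of sequence length. A quick sketch of the resulting shapes; only the field names and shape formulas come from the diff, all numeric values are made up:

# Hypothetical config values for illustration.
hidden_size = 1024
mamba_expand = 2
mamba_d_state = 16
mamba_d_conv = 4
n_mamba_heads = 2
batch_size = 1

intermediate_size = mamba_expand * hidden_size                    # 2048
conv_state_shape = (batch_size, intermediate_size, mamba_d_conv)  # (1, 2048, 4), one per layer
ssm_state_shape = (
    batch_size,
    n_mamba_heads,
    intermediate_size // n_mamba_heads,
    mamba_d_state,
)                                                                 # (1, 2, 1024, 16), one per layer
print(conv_state_shape, ssm_state_shape)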
@@ -194,14 +199,38 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             return 0
         return self.key_cache[layer_idx].shape[-2]

     # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+        raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")

     @classmethod
     # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache
     def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
-        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+        raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")


+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class ZambaAttention(nn.Module):
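The new module-level eager path leans on the existing repeat_kv helper (its return statement appears in an earlier hunk) to expand grouped key/value heads before the matmul. A small self-contained check of what that expand-and-reshape computes, assuming the usual grouped-query-attention semantics:

import torch

batch, num_kv_heads, slen, head_dim, n_rep = 1, 2, 5, 4, 3
x = torch.randn(batch, num_kv_heads, slen, head_dim)

# expand-and-reshape, as in repeat_kv
expanded = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
out = expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# same result as repeating each kv head n_rep times in place
assert torch.equal(out, torch.repeat_interleave(x, n_rep, dim=1))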
@@ -218,277 +247,67 @@ class ZambaAttention(nn.Module):
     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
     """

-    def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: ZambaConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx

-        self.hidden_size = config.hidden_size
         self.attention_hidden_size = config.attention_hidden_size
-        self.num_heads = config.num_attention_heads
         self.head_dim = config.attention_head_dim
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = (self.head_dim / 2) ** -0.5
         self.is_causal = True
         self.attention_dropout = config.attention_dropout

-        if (self.head_dim * self.num_heads) != self.attention_hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-        self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.q_proj = nn.Linear(config.attention_hidden_size, config.num_attention_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

     def forward(
         self,
         hidden_states: torch.Tensor,
         layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[ZambaHybridDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

         if past_key_value is not None:
             key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)

-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size)
-
-        attn_output = attn_output
-        attn_output = self.o_proj(attn_output)
-        attn_output = attn_output
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
-# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward
-# dropped use_sliding_windows from the arguments of self._flash_attention_forward
-class ZambaFlashAttention2(ZambaAttention):
-    """
-    Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
-        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        if past_key_value is not None:
-            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        softmax_scale = 1 / math.sqrt(self.head_dim / 2)
-
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            softmax_scale=softmax_scale,
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
-# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
-class ZambaSdpaAttention(ZambaAttention):
-    """
-    Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        if past_key_value is not None:
-            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and attention_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        softmax_scale = 1 / math.sqrt(self.head_dim / 2)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
-            scale=softmax_scale,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-ZAMBA_ATTENTION_CLASSES = {
-    "eager": ZambaAttention,
-    "flash_attention_2": ZambaFlashAttention2,
-    "sdpa": ZambaSdpaAttention,
-}
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights


 class ZambaMambaMixer(nn.Module):
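One behavioral detail worth noting in the hunk above: Zamba divides attention logits by sqrt(head_dim / 2) rather than the usual sqrt(head_dim), plausibly because the attention input is the concatenation of two hidden states (see the torch.concatenate in ZambaAttentionDecoderLayer further down). The refactor computes this once as self.scaling and passes it to every backend. A tiny check that the removed softmax_scale and the new scaling expressions agree (the head_dim value is hypothetical):

import math

head_dim = 128  # hypothetical value for config.attention_head_dim

old_softmax_scale = 1 / math.sqrt(head_dim / 2)  # removed ZambaFlashAttention2 / ZambaSdpaAttention
new_scaling = (head_dim / 2) ** -0.5             # self.scaling in the unified ZambaAttention
assert math.isclose(old_softmax_scale, new_scaling)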
@@ -568,7 +387,7 @@ class ZambaMambaMixer(nn.Module):
         )

     def cuda_kernels_forward(
-        self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None
+        self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None
     ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
@@ -664,7 +483,7 @@ class ZambaMambaMixer(nn.Module):
         contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
         return contextualized_states

-    def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+    def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated linear projection
@@ -675,7 +494,7 @@ class ZambaMambaMixer(nn.Module):
         gate = gate.squeeze(2)
         gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)

-        use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
+        use_cache = isinstance(cache_params, ZambaHybridDynamicCache)
         # 2. Convolution sequence transformation
         if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
             if self.training:
@@ -757,7 +576,7 @@ class ZambaMambaMixer(nn.Module):
         )
         return contextualized_states

-    def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+    def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
         if self.use_fast_kernels:
             if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
                 raise ValueError(
@@ -789,7 +608,7 @@ class ZambaMLP(nn.Module):
 class ZambaAttentionDecoderLayer(nn.Module):
     def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):
         super().__init__()
-        self.self_attn = ZAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.self_attn = ZambaAttention(config, layer_idx)

         self.feed_forward = ZambaMLP(config)
         self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps)
@@ -802,11 +621,11 @@ class ZambaAttentionDecoderLayer(nn.Module):
         layer_idx: int,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -815,9 +634,11 @@ class ZambaAttentionDecoderLayer(nn.Module):
                 This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                 concatenated tensor is then used as input of the pre-attention RMSNorm
                 (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+            layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings.
+            past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -829,7 +650,7 @@ class ZambaAttentionDecoderLayer(nn.Module):
         """
         hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             layer_idx=layer_idx,
             attention_mask=attention_mask,
@@ -849,9 +670,6 @@ class ZambaAttentionDecoderLayer(nn.Module):
         if output_attentions:
             outputs += (self_attn_weights,)

-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs


@@ -870,7 +688,7 @@ class ZambaMambaDecoderLayer(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         causal_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -881,7 +699,7 @@ class ZambaMambaDecoderLayer(nn.Module):
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -923,7 +741,7 @@ class ZambaMambaDecoderLayer(nn.Module):
         return outputs


-class HybridLayer(nn.Module):
+class ZambaHybridLayer(nn.Module):
     def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
         super().__init__()
         self.shared_transf = shared_transf
@@ -938,7 +756,7 @@ class HybridLayer(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         causal_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
@@ -951,7 +769,7 @@ class HybridLayer(nn.Module):
             layer_idx (`int`): layer number.
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+            past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -1027,7 +845,7 @@ class ZambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = False
     _supports_sdpa = False
-    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
+    _supports_cache_class = True  # Note: only supports ZambaHybridDynamicCache
     _is_stateful = True

     def _init_weights(self, module):
@@ -1121,14 +939,14 @@ ZAMBA_INPUTS_DOCSTRING = r"""
             config.n_positions - 1]`.

             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
+        past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the
             self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
             Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
             Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
             `(batch_size, d_inner, d_state)` respectively.
-            See the `HybridMambaAttentionDynamicCache` class for more details.
+            See the `ZambaHybridDynamicCache` class for more details.

             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@@ -1202,7 +1020,7 @@ class ZambaModel(ZambaPreTrainedModel):
                     "shared_transf.pre_ff_layernorm.weight",
                 ]
                 self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
-                layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers)))
+                layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers)))
             else:
                 layers.append(next(mamba_layers))
         self.layers = nn.ModuleList(layers)
@@ -1226,7 +1044,7 @@ class ZambaModel(ZambaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[ZambaHybridDynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1263,7 +1081,7 @@ class ZambaModel(ZambaPreTrainedModel):

         if use_cache and past_key_values is None:
             logger.warning_once(
-                "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+                "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was "
                 "provided, so no cache will be returned."
             )

@@ -1324,17 +1142,13 @@ class ZambaModel(ZambaPreTrainedModel):
         if past_key_values and not past_key_values.has_previous_state:
             past_key_values.has_previous_state = True

-        next_cache = None if not use_cache else past_key_values
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
-        return BaseModelOutputWithPast(
+        output = BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values if use_cache else None,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
+        return output if return_dict else output.to_tuple()

     # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
     def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
@@ -1410,7 +1224,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[ZambaHybridDynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -1504,7 +1318,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
         use_cache=True,
         **kwargs,
     ):
-        # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
+        # Overwitten -- has a unique cache type, `ZambaHybridDynamicCache`

         empty_past_kv = past_key_values is None

@@ -1518,7 +1332,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
             input_ids = input_ids[:, cache_position]
         else:
-            past_key_values = HybridMambaAttentionDynamicCache(
+            past_key_values = ZambaHybridDynamicCache(
                 self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
             )

@@ -46,7 +46,7 @@ if is_torch_available():
         ZambaModel,
     )
     from transformers.models.zamba.modeling_zamba import (
-        HybridMambaAttentionDynamicCache,
+        ZambaHybridDynamicCache,
     )


@@ -215,9 +215,7 @@ class ZambaModelTester:

         # first forward pass
         # Attention: Zamba needs the cache to be initialized to return a cache!
-        past_key_values = HybridMambaAttentionDynamicCache(
-            config, input_ids.shape[0], model.dtype, device=model.device
-        )
+        past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
         outputs = model(
             input_ids,
             attention_mask=input_mask,
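For downstream code the rename is mechanical, but as the test change above shows, the cache must still be created by the caller for a manual forward pass (generate() does this internally via prepare_inputs_for_generation). A hedged usage sketch; the checkpoint id is illustrative and any Zamba checkpoint should work:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.zamba.modeling_zamba import ZambaHybridDynamicCache

model_id = "Zyphra/Zamba-7B-v1"  # illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
# Zamba returns a cache only if one is passed in (see the warning in ZambaModel.forward).
past_key_values = ZambaHybridDynamicCache(
    model.config, inputs.input_ids.shape[0], dtype=model.dtype, device=model.device
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)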