Zamba new attention standard (#35375)

* Updated Zamba to the new attention standard

* `make fixup` fixes
pglorio 2025-01-06 23:08:45 -10:00 committed by GitHub
parent 12ba96aa3c
commit bd442c6d3a
2 changed files with 104 additions and 292 deletions
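
The core of the change is visible in the modeling diff below: the separate `ZambaAttention` / `ZambaFlashAttention2` / `ZambaSdpaAttention` classes are collapsed into a single `ZambaAttention` whose forward dispatches through `ALL_ATTENTION_FUNCTIONS`. A rough, hedged sketch of that dispatch pattern follows; the names mirror the diff, and `eager_attention_forward` is the module-level fallback added by this PR (assumed importable once the change is merged):

```python
# Sketch only -- not verbatim from the diff. Selects the attention backend the way the
# refactored ZambaAttention.forward does: default to the local eager implementation,
# otherwise look the backend up in ALL_ATTENTION_FUNCTIONS by config._attn_implementation.
from typing import Callable

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.zamba.modeling_zamba import eager_attention_forward  # added by this PR


def select_attention_interface(config, output_attentions: bool = False) -> Callable:
    attention_interface: Callable = eager_attention_forward
    if config._attn_implementation != "eager":
        if config._attn_implementation == "sdpa" and output_attentions:
            # SDPA cannot return attention weights, so the diff falls back to eager here.
            return eager_attention_forward
        attention_interface = ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
    return attention_interface
```
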

View File

@ -20,7 +20,7 @@
"""PyTorch Zamba model."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -33,18 +33,18 @@ from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
)
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class HybridMambaAttentionDynamicCache(DynamicCache):
class ZambaHybridDynamicCache(DynamicCache):
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
@ -131,9 +131,9 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
intermediate_size = config.mamba_expand * config.hidden_size
ssm_state_size = config.mamba_d_state
conv_kernel_size = config.mamba_d_conv
self.intermediate_size = config.mamba_expand * config.hidden_size
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.conv_states = []
self.ssm_states = []
@ -143,9 +143,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
self._buffers = {}
for i in range(config.num_hidden_layers):
self.conv_states += [
torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
]
cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
cache_shape = (
batch_size,
self.n_mamba_heads,
self.intermediate_size // self.n_mamba_heads,
self.ssm_state_size,
)
self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
if self.layers_block_type[i] == "hybrid":
self.transformer_layers.append(i)
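
For orientation, here is a minimal sketch (not part of the diff) of what the renamed cache allocates, assuming the default `ZambaConfig()` is usable as-is; the constructor signature matches the calls in `prepare_inputs_for_generation` and in the test file further down:

```python
# Hedged sketch: inspect the per-layer states ZambaHybridDynamicCache.__init__ builds.
import torch
from transformers import ZambaConfig
from transformers.models.zamba.modeling_zamba import ZambaHybridDynamicCache

config = ZambaConfig()  # default config, used here only for its shape-defining fields
batch_size = 2
cache = ZambaHybridDynamicCache(config, batch_size, dtype=torch.float32, device="cpu")

# One conv state per layer: (batch, mamba_expand * hidden_size, mamba_d_conv)
print(cache.conv_states[0].shape)
# One ssm state per layer: (batch, n_mamba_heads, intermediate_size // n_mamba_heads, mamba_d_state)
print(cache.ssm_states[0].shape)
# Indices of the hybrid layers that also keep an attention key/value cache
print(cache.transformer_layers[:4])
```
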
@ -194,14 +199,38 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
return 0
return self.key_cache[layer_idx].shape[-2]
# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")
@classmethod
# Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
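
A small, hedged usage sketch of the helper above (not in the diff): it assumes `eager_attention_forward` is importable from `transformers.models.zamba.modeling_zamba` once this change lands, and only illustrates the GQA shapes and Zamba's `1 / sqrt(head_dim / 2)` scaling.

```python
# Sketch only. `module` just needs the two attributes eager_attention_forward reads.
import torch
from types import SimpleNamespace
from transformers.models.zamba.modeling_zamba import eager_attention_forward

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 4, 2, 8, 16
module = SimpleNamespace(num_key_value_groups=num_heads // num_kv_heads, training=False)

query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

attn_output, attn_weights = eager_attention_forward(
    module, query, key, value,
    attention_mask=None,
    scaling=(head_dim / 2) ** -0.5,   # Zamba's non-standard softmax scale
)
print(attn_output.shape)   # (batch, seq_len, num_heads, head_dim), reshaped by the caller
print(attn_weights.shape)  # (batch, num_heads, seq_len, seq_len)
```
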
class ZambaAttention(nn.Module):
@ -218,277 +247,67 @@ class ZambaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
"""
def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):
def __init__(self, config: ZambaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.attention_hidden_size = config.attention_hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.attention_head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.scaling = (self.head_dim / 2) ** -0.5
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.attention_hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.q_proj = nn.Linear(config.attention_hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[ZambaHybridDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward
# dropped use_sliding_windows from the arguments of self._flash_attention_forward
class ZambaFlashAttention2(ZambaAttention):
"""
Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
softmax_scale = 1 / math.sqrt(self.head_dim / 2)
attn_output = _flash_attention_forward(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
softmax_scale=softmax_scale,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
class ZambaSdpaAttention(ZambaAttention):
"""
Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
def forward(
self,
hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
softmax_scale = 1 / math.sqrt(self.head_dim / 2)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
scale=softmax_scale,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
ZAMBA_ATTENTION_CLASSES = {
"eager": ZambaAttention,
"flash_attention_2": ZambaFlashAttention2,
"sdpa": ZambaSdpaAttention,
}
return attn_output, attn_weights
class ZambaMambaMixer(nn.Module):
@ -568,7 +387,7 @@ class ZambaMambaMixer(nn.Module):
)
def cuda_kernels_forward(
self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None
self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
@ -664,7 +483,7 @@ class ZambaMambaMixer(nn.Module):
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated linear projection
@ -675,7 +494,7 @@ class ZambaMambaMixer(nn.Module):
gate = gate.squeeze(2)
gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)
use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
use_cache = isinstance(cache_params, ZambaHybridDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
@ -757,7 +576,7 @@ class ZambaMambaMixer(nn.Module):
)
return contextualized_states
def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
raise ValueError(
@ -789,7 +608,7 @@ class ZambaMLP(nn.Module):
class ZambaAttentionDecoderLayer(nn.Module):
def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.self_attn = ZAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = ZambaAttention(config, layer_idx)
self.feed_forward = ZambaMLP(config)
self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps)
@ -802,11 +621,11 @@ class ZambaAttentionDecoderLayer(nn.Module):
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -815,9 +634,11 @@ class ZambaAttentionDecoderLayer(nn.Module):
This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
concatenated tensor is then used as input of the pre-attention RMSNorm
(see fig. 2 in https://arxiv.org/pdf/2405.16712).
layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings.
past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@ -829,7 +650,7 @@ class ZambaAttentionDecoderLayer(nn.Module):
"""
hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
@ -849,9 +670,6 @@ class ZambaAttentionDecoderLayer(nn.Module):
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
@ -870,7 +688,7 @@ class ZambaMambaDecoderLayer(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@ -881,7 +699,7 @@ class ZambaMambaDecoderLayer(nn.Module):
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@ -923,7 +741,7 @@ class ZambaMambaDecoderLayer(nn.Module):
return outputs
class HybridLayer(nn.Module):
class ZambaHybridLayer(nn.Module):
def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
super().__init__()
self.shared_transf = shared_transf
@ -938,7 +756,7 @@ class HybridLayer(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@ -951,7 +769,7 @@ class HybridLayer(nn.Module):
layer_idx (`int`): layer number.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@ -1027,7 +845,7 @@ class ZambaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_sdpa = False
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
_supports_cache_class = True # Note: only supports ZambaHybridDynamicCache
_is_stateful = True
def _init_weights(self, module):
@ -1121,14 +939,14 @@ ZAMBA_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the
self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
`(batch_size, d_inner, d_state)` respectively.
See the `HybridMambaAttentionDynamicCache` class for more details.
See the `ZambaHybridDynamicCache` class for more details.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@ -1202,7 +1020,7 @@ class ZambaModel(ZambaPreTrainedModel):
"shared_transf.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers)))
layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers)))
else:
layers.append(next(mamba_layers))
self.layers = nn.ModuleList(layers)
@ -1226,7 +1044,7 @@ class ZambaModel(ZambaPreTrainedModel):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -1263,7 +1081,7 @@ class ZambaModel(ZambaPreTrainedModel):
if use_cache and past_key_values is None:
logger.warning_once(
"Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
"Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was "
"provided, so no cache will be returned."
)
@ -1324,17 +1142,13 @@ class ZambaModel(ZambaPreTrainedModel):
if past_key_values and not past_key_values.has_previous_state:
past_key_values.has_previous_state = True
next_cache = None if not use_cache else past_key_values
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
# Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
@ -1410,7 +1224,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -1504,7 +1318,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
use_cache=True,
**kwargs,
):
# Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
# Overwritten -- has a unique cache type, `ZambaHybridDynamicCache`
empty_past_kv = past_key_values is None
@ -1518,7 +1332,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
else:
past_key_values = HybridMambaAttentionDynamicCache(
past_key_values = ZambaHybridDynamicCache(
self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
)

View File

@ -46,7 +46,7 @@ if is_torch_available():
ZambaModel,
)
from transformers.models.zamba.modeling_zamba import (
HybridMambaAttentionDynamicCache,
ZambaHybridDynamicCache,
)
@ -215,9 +215,7 @@ class ZambaModelTester:
# first forward pass
# Attention: Zamba needs the cache to be initialized to return a cache!
past_key_values = HybridMambaAttentionDynamicCache(
config, input_ids.shape[0], model.dtype, device=model.device
)
past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(
input_ids,
attention_mask=input_mask,