Fix RMSNormGated in Zamba2 (#35943)

* First commit

* Finish model implementation

* First commit

* Finish model implementation

* Register zamba2

* generated modeling and configuration

* generated modeling and configuration

* added hybrid cache

* fix attention_mask in mamba

* dropped unused loras

* fix flash2

* config docstrings

* fix config and fwd pass

* make fixup fixes

* test_modeling_zamba2

* small fixes

* make fixup fixes

* Fix modular model converter

* added inheritances in modular, renamed zamba cache

* modular rebase

* new modular conversion

* fix generated modeling file

* fixed import for Zamba2RMSNormGated

* modular file cleanup

* make fixup and model tests

* dropped inheritance for Zamba2PreTrainedModel

* make fixup and unit tests

* Add inheritance of rope from GemmaRotaryEmbedding

* moved rope to model init

* drop del self.self_attn and del self.feed_forward

* fix tests

* renamed lora -> adapter

* rewrote adapter implementation

* fixed tests

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Dropped adapter in-place sum

* removed rope from attention init

* updated rope

* created get_layers method

* make fixup fix

* make fixup fixes

* make fixup fixes

* update to new attention standard

* update to new attention standard

* make fixup fixes

* minor fixes

* cache_position

* removed cache_position, position_ids, use_cache

* remove config from modular

* removed config from modular (2)

* import apply_rotary_pos_emb from llama

* fixed rope_kwargs

* Instantiate cache in Zamba2Model

* fix cache

* fix @slow decorator

* small fix in modular file

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* several minor fixes

* inherit mamba2decoder fwd and drop position_ids in mamba

* removed docstrings from modular

* reinstate zamba2 attention decoder fwd

* use regex for tied keys

* Revert "use regex for tied keys"

This reverts commit 9007a522b1.

* use regex for tied keys

* add cpu to slow forward tests

* dropped config.use_shared_mlp_adapter

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* re-convert from modular

* extended Zamba2RMSNormGated to n_groups>1

* removed einops import

* set _supports_sdpa = True

* add use_mem_eff_path flag for fused mamba2 fwd

* added docstring for use_mem_eff_path flag

---------

Co-authored-by: root <root@node-2.us-southcentral1-a.compute.internal>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

commit a93b80588b
parent bc9a6d8302
Authored by pglorio on 2025-02-04 05:28:04 -08:00; committed by GitHub

3 changed files with 44 additions and 14 deletions

@@ -64,6 +64,8 @@ class Zamba2Config(PretrainedConfig):
             Whether or not to use bias in the convolution layer of the mixer block.
         chunk_size (`int`, *optional*, defaults to 256):
             Size of the chunks that will comprise the sequence.
+        use_mem_eff_path (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the fused conv1d and scan in mamba2 layers.
         add_bias_linear (`bool`, *optional*, defaults to `False`):
             Flag indicating whether or not to use bias in various layers
         intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
@@ -143,6 +145,7 @@ class Zamba2Config(PretrainedConfig):
         n_mamba_heads=8,
         use_conv_bias=True,
         chunk_size=256,
+        use_mem_eff_path=False,
         add_bias_linear=False,
         intermediate_size=None,
         hidden_act="gelu",
@@ -231,6 +234,7 @@ class Zamba2Config(PretrainedConfig):
         self.use_cache = use_cache
         self.num_logits_to_keep = num_logits_to_keep
         self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"]
+        self.use_mem_eff_path = use_mem_eff_path
 
 
 __all__ = ["Zamba2Config"]
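
Note (illustrative, not part of the diff): the new flag is read from the config, so enabling the fused conv1d + scan path is a one-line config change. The sketch below assumes a transformers build that ships Zamba2 (`Zamba2Config` / `Zamba2ForCausalLM`) and, for the fused kernel to actually run, the optional mamba-ssm package.

# Minimal sketch: enable the fused path added in this PR (assumes mamba-ssm kernels are installed).
from transformers import Zamba2Config, Zamba2ForCausalLM

config = Zamba2Config(use_mem_eff_path=True)  # defaults to False
model = Zamba2ForCausalLM(config)
model.train()  # the fused kernel is only taken in training, with no cache and an unmasked input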

@@ -62,20 +62,23 @@ _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
 
 
 class Zamba2RMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
         return self.weight * hidden_states.to(input_dtype)
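
Note (illustrative usage, not part of the diff): the new `group_size` argument computes the RMS statistics per group of channels rather than over the full hidden dimension; with `group_size == hidden_size` the previous behavior is recovered. Shapes below are made up; the import path is the post-PR modeling module.

# Sketch of the grouped normalization added above (hypothetical sizes).
import torch
from transformers.models.zamba2.modeling_zamba2 import Zamba2RMSNormGated

hidden = torch.randn(2, 5, 64)   # (batch, seq_len, intermediate_size)
gate = torch.randn(2, 5, 64)

single_group = Zamba2RMSNormGated(hidden_size=64, group_size=64)  # n_groups == 1: previous behavior
four_groups = Zamba2RMSNormGated(hidden_size=64, group_size=16)   # n_groups == 4: per-group statistics

out_full = single_group(hidden, gate)    # variance computed over all 64 channels
out_grouped = four_groups(hidden, gate)  # variance computed within each group of 16 channels
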
@@ -563,6 +566,7 @@ class Zamba2MambaMixer(nn.Module):
         self.use_conv_bias = config.use_conv_bias
         self.activation = "silu"
         self.act = nn.SiLU()
+        self.use_mem_eff_path = config.use_mem_eff_path
 
         self.n_groups = config.mamba_ngroups
         self.head_dim = config.mamba_headdim
@@ -601,7 +605,9 @@ class Zamba2MambaMixer(nn.Module):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
@@ -685,7 +691,7 @@ class Zamba2MambaMixer(nn.Module):
         else:
             input_not_masked = True
 
-        if self.training and cache_params is None and input_not_masked:
+        if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
             out, ssm_state = mamba_split_conv1d_scan_combined(
                 projected_states,
                 self.conv1d.weight.squeeze(1),
@@ -1227,7 +1233,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True
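
Note (illustrative, not part of the diff): with `_supports_sdpa` flipped to `True`, PyTorch's scaled-dot-product attention backend can be requested at load time; before this change transformers would reject `attn_implementation="sdpa"` for Zamba2. Checkpoint name taken from `_CONFIG_FOR_DOC` above.

# Minimal sketch: load the model with the SDPA attention backend.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-2.7B",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)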

@@ -35,7 +35,7 @@ from ...utils.import_utils import (
     is_mamba_ssm_available,
 )
 from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
-from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum
+from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
 from ..zamba.modeling_zamba import (
     ZambaAttention,
     ZambaAttentionDecoderLayer,
@@ -70,8 +70,25 @@ _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
 logger = logging.get_logger(__name__)
 
 
-class Zamba2RMSNormGated(MambaRMSNormGated):
-    pass
+class Zamba2RMSNormGated(torch.nn.Module):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.group_size = group_size
+
+    def forward(self, hidden_states, gate=None):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        if gate is not None:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
+        return self.weight * hidden_states.to(input_dtype)
 
 
 class Zamba2RMSNorm(ZambaRMSNorm):
@@ -296,6 +313,7 @@ class Zamba2MambaMixer(nn.Module):
         self.use_conv_bias = config.use_conv_bias
         self.activation = "silu"
         self.act = nn.SiLU()
+        self.use_mem_eff_path = config.use_mem_eff_path
 
         self.n_groups = config.mamba_ngroups
         self.head_dim = config.mamba_headdim
@@ -334,7 +352,9 @@ class Zamba2MambaMixer(nn.Module):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
@@ -418,7 +438,7 @@ class Zamba2MambaMixer(nn.Module):
         else:
             input_not_masked = True
 
-        if self.training and cache_params is None and input_not_masked:
+        if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
             out, ssm_state = mamba_split_conv1d_scan_combined(
                 projected_states,
                 self.conv1d.weight.squeeze(1),
@@ -896,7 +916,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True
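
Note (illustrative, not part of the diff): the modular file drops the `MambaRMSNormGated` import because the grouped variant above no longer matches it once `n_groups > 1`; for a single group the two remain numerically identical, which the sanity-check sketch below (standard transformers module paths, made-up shapes) would confirm.

# Sanity-check sketch: with one group, the new norm reproduces MambaRMSNormGated.
import torch
from transformers.models.mamba2.modeling_mamba2 import MambaRMSNormGated
from transformers.models.zamba2.modeling_zamba2 import Zamba2RMSNormGated

x, gate = torch.randn(2, 4, 32), torch.randn(2, 4, 32)
old = MambaRMSNormGated(32, eps=1e-5)
new = Zamba2RMSNormGated(32, group_size=32, eps=1e-5)   # group_size == hidden_size
torch.testing.assert_close(old(x, gate), new(x, gate))  # same computation when n_groups == 1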