mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fix RMSNormGated in Zamba2 (#35943)
* First commit
* Finish model implementation
* First commit
* Finish model implementation
* Register zamba2
* generated modeling and configuration
* generated modeling and configuration
* added hybrid cache
* fix attention_mask in mamba
* dropped unused loras
* fix flash2
* config docstrings
* fix config and fwd pass
* make fixup fixes
* text_modeling_zamba2
* small fixes
* make fixup fixes
* Fix modular model converter
* added inheritances in modular, renamed zamba cache
* modular rebase
* new modular conversion
* fix generated modeling file
* fixed import for Zamba2RMSNormGated
* modular file cleanup
* make fixup and model tests
* dropped inheritance for Zamba2PreTrainedModel
* make fixup and unit tests
* Add inheritance of rope from GemmaRotaryEmbedding
* moved rope to model init
* drop del self.self_attn and del self.feed_forward
* fix tests
* renamed lora -> adapter
* rewrote adapter implementation
* fixed tests
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Dropped adapter in-place sum
* removed rope from attention init
* updated rope
* created get_layers method
* make fixup fix
* make fixup fixes
* make fixup fixes
* update to new attention standard
* update to new attention standard
* make fixup fixes
* minor fixes
* cache_position
* removed cache_position postion_ids use_cache
* remove config from modular
* removed config from modular (2)
* import apply_rotary_pos_emb from llama
* fixed rope_kwargs
* Instantiate cache in Zamba2Model
* fix cache
* fix @slow decorator
* small fix in modular file
* Update docs/source/en/model_doc/zamba2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* several minor fixes
* inherit mamba2decoder fwd and drop position_ids in mamba
* removed docstrings from modular
* reinstate zamba2 attention decoder fwd
* use regex for tied keys
* Revert "use regex for tied keys"
This reverts commit 9007a522b1.
* use regex for tied keys
* add cpu to slow forward tests
* dropped config.use_shared_mlp_adapter
* Update docs/source/en/model_doc/zamba2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* re-convert from modular
* extended Zamba2RMSNormGated to n_groups>1
* removed einops import
* set _supports_sdpa = True
* add use_mem_eff_path flag for fused mamba2 fwd
* added docstring for use_mem_eff_path flag
---------
Co-authored-by: root <root@node-2.us-southcentral1-a.compute.internal>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent bc9a6d8302
commit a93b80588b
@@ -64,6 +64,8 @@ class Zamba2Config(PretrainedConfig):
             Whether or not to use bias in the convolution layer of the mixer block.
         chunk_size (`int`, *optional*, defaults to 256):
             Size of the chunks that will comprise the sequence.
+        use_mem_eff_path (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the fused conv1d and scan in mamba2 layers.
         add_bias_linear (`bool`, *optional*, defaults to `False`):
             Flag indicating whether or not to use bias in various layers
         intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
@@ -143,6 +145,7 @@ class Zamba2Config(PretrainedConfig):
         n_mamba_heads=8,
         use_conv_bias=True,
         chunk_size=256,
+        use_mem_eff_path=False,
         add_bias_linear=False,
         intermediate_size=None,
         hidden_act="gelu",
@@ -231,6 +234,7 @@ class Zamba2Config(PretrainedConfig):
         self.use_cache = use_cache
         self.num_logits_to_keep = num_logits_to_keep
         self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"]
+        self.use_mem_eff_path = use_mem_eff_path
 
 
 __all__ = ["Zamba2Config"]
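Taken together, the three config hunks above add an opt-in `use_mem_eff_path` flag. A minimal sketch of enabling it, assuming a `transformers` build that already ships Zamba2:

```python
from transformers import Zamba2Config

# Sketch: the fused path is opt-in; everything else keeps the defaults shown
# in the hunks above (chunk_size=256, add_bias_linear=False, ...).
config = Zamba2Config(use_mem_eff_path=True)
print(config.use_mem_eff_path)  # True
```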
@@ -62,20 +62,23 @@ _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
 
 
 class Zamba2RMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size
 
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
-
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
         return self.weight * hidden_states.to(input_dtype)
 
 
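For orientation, a short usage sketch of the rewritten norm (hypothetical small sizes; the import path assumes the generated modeling file shown above):

```python
import torch
from transformers.models.zamba2.modeling_zamba2 import Zamba2RMSNormGated

# Hypothetical sizes: hidden_size=8 split into two groups of group_size=4,
# each group normalized independently before the learned scale is applied.
norm = Zamba2RMSNormGated(hidden_size=8, group_size=4, eps=1e-5)

x = torch.randn(2, 3, 8)      # (batch, seq_len, hidden_size)
gate = torch.randn(2, 3, 8)   # gating branch, passed through SiLU inside forward
out = norm(x, gate)
print(out.shape)              # torch.Size([2, 3, 8])
```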
@@ -563,6 +566,7 @@ class Zamba2MambaMixer(nn.Module):
         self.use_conv_bias = config.use_conv_bias
         self.activation = "silu"
         self.act = nn.SiLU()
+        self.use_mem_eff_path = config.use_mem_eff_path
 
         self.n_groups = config.mamba_ngroups
         self.head_dim = config.mamba_headdim
@@ -601,7 +605,9 @@ class Zamba2MambaMixer(nn.Module):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
 
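The constructor change wires the group size to the mixer shapes; a tiny worked example with illustrative values only (not from a particular checkpoint):

```python
# Illustrative numbers: the norm now normalizes each of the n_groups channel
# groups of the mixer's intermediate projection separately.
intermediate_size = 4096          # mamba expansion of the hidden size
n_groups = 2                      # config.mamba_ngroups

group_size = intermediate_size // n_groups
print(group_size)                 # 2048 channels per normalized group
```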
@@ -685,7 +691,7 @@ class Zamba2MambaMixer(nn.Module):
         else:
             input_not_masked = True
 
-        if self.training and cache_params is None and input_not_masked:
+        if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
             out, ssm_state = mamba_split_conv1d_scan_combined(
                 projected_states,
                 self.conv1d.weight.squeeze(1),
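The only behavioral change in this hunk is the extra `self.use_mem_eff_path` condition; a hedged sketch of the resulting dispatch (the helper below is illustrative, not code from the repo):

```python
def mamba2_forward_path(use_mem_eff_path, training, cache_params, input_not_masked):
    """Illustrative helper (not in the repo): which path the mixer takes."""
    if use_mem_eff_path and training and cache_params is None and input_not_masked:
        return "fused"    # mamba_split_conv1d_scan_combined kernel
    return "unfused"      # standard conv1d + chunked scan path

# With the flag left at its default (False), training no longer hits the fused kernel.
print(mamba2_forward_path(False, True, None, True))  # unfused
print(mamba2_forward_path(True, True, None, True))   # fused
```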
@@ -1227,7 +1233,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True
 
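With `_supports_sdpa` flipped to `True`, the SDPA attention backend becomes selectable at load time; a hedged loading sketch (checkpoint name taken from `_CONFIG_FOR_DOC` above, weights download required):

```python
import torch
from transformers import AutoModelForCausalLM

# Sketch only: request PyTorch SDPA attention now that _supports_sdpa is True.
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-2.7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```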
@@ -35,7 +35,7 @@ from ...utils.import_utils import (
     is_mamba_ssm_available,
 )
 from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
-from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum
+from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
 from ..zamba.modeling_zamba import (
     ZambaAttention,
     ZambaAttentionDecoderLayer,
@@ -70,8 +70,25 @@ _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
 logger = logging.get_logger(__name__)
 
 
-class Zamba2RMSNormGated(MambaRMSNormGated):
-    pass
+class Zamba2RMSNormGated(torch.nn.Module):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.group_size = group_size
+
+    def forward(self, hidden_states, gate=None):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        if gate is not None:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
+        return self.weight * hidden_states.to(input_dtype)
 
 
 class Zamba2RMSNorm(ZambaRMSNorm):
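One sanity property of the class above: with `n_groups == 1` the single group spans the whole hidden dimension, so the output matches the previous ungrouped gated RMSNorm. A small self-check sketch (import path assumes the generated modeling file):

```python
import torch
from torch import nn
from transformers.models.zamba2.modeling_zamba2 import Zamba2RMSNormGated

hidden_size = 16
x = torch.randn(2, 5, hidden_size)
gate = torch.randn(2, 5, hidden_size)

# A single group (group_size == hidden_size) reproduces the old full-width norm.
grouped = Zamba2RMSNormGated(hidden_size, group_size=hidden_size, eps=1e-6)

# Ungrouped gated RMSNorm, written out inline as a reference.
ref = x * nn.functional.silu(gate)
ref = ref * torch.rsqrt(ref.pow(2).mean(-1, keepdim=True) + 1e-6)

torch.testing.assert_close(grouped(x, gate), ref)
```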
@@ -296,6 +313,7 @@ class Zamba2MambaMixer(nn.Module):
         self.use_conv_bias = config.use_conv_bias
         self.activation = "silu"
         self.act = nn.SiLU()
+        self.use_mem_eff_path = config.use_mem_eff_path
 
         self.n_groups = config.mamba_ngroups
         self.head_dim = config.mamba_headdim
@@ -334,7 +352,9 @@ class Zamba2MambaMixer(nn.Module):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True
 
@@ -418,7 +438,7 @@ class Zamba2MambaMixer(nn.Module):
         else:
             input_not_masked = True
 
-        if self.training and cache_params is None and input_not_masked:
+        if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
             out, ssm_state = mamba_split_conv1d_scan_combined(
                 projected_states,
                 self.conv1d.weight.squeeze(1),
@@ -896,7 +916,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True
 