Add config validation and style tweaks (#37589)

* Add config validation and style tweaks

* Fix style issues

* Fix style issues

* style

* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
This commit is contained in:
Kirire 2025-05-14 14:22:10 +02:00 committed by GitHub
parent 1b00966395
commit 935bbbc711
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 21 deletions

View File

@ -140,6 +140,13 @@ class Mamba2Config(PretrainedConfig):
tie_word_embeddings=False,
**kwargs,
):
if (hidden_size * expand) != (num_heads * head_dim):
raise ValueError(
"Inconsistent configuration: hidden_size * expand "
f"({hidden_size * expand}) must equal num_heads * head_dim "
f"({num_heads * head_dim})."
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size

View File

@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
@ -457,13 +456,19 @@ class Mamba2Mixer(nn.Module):
return out
# fmt: off
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
def torch_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Mamba2Cache]=None,
cache_position:Optional[torch.LongTensor]=None,
attention_mask: Optional[torch.Tensor]=None
):
batch_size, seq_len, _ = hidden_states.shape
dtype = hidden_states.dtype
# 1. Gated MLP's linear projection
input_states = apply_mask_to_padding_states(input_states, attention_mask)
projected_states = self.in_proj(input_states)
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
projected_states = self.in_proj(hidden_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
@ -657,11 +662,6 @@ class Mamba2Mixer(nn.Module):
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
@ -1018,7 +1018,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, # for now we need this for generation
**kwargs, # for now we need this for generation and loss_function
) -> Union[Tuple, Mamba2CausalLMOutput]:
r"""
cache_params (`Mamba2Cache`, *optional*):
@ -1052,14 +1052,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + mamba2_outputs[1:]

View File

@ -44,6 +44,29 @@ if is_torch_available():
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
class Mamba2ConfigTester(ConfigTester):
def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
_input_dict = self.inputs_dict.copy()
_input_dict["hidden_size"] = hidden_size
_input_dict["num_heads"] = num_heads
_input_dict["expand"] = expand
_input_dict["head_dim"] = head_dim
return self.config_class(**_input_dict)
def test_hidden_size_compatibility(self):
self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
def run_common_tests(self):
self.test_hidden_size_compatibility()
return super().run_common_tests()
class Mamba2ModelTester:
def __init__(
self,
@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def setUp(self):
self.model_tester = Mamba2ModelTester(self)
self.config_tester = ConfigTester(
self.config_tester = Mamba2ConfigTester(
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)