Add config validation and style tweaks (#37589)

* Add config validation and style tweaks

* Fix style issues

* Fix style issues

* style

* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
Authored by Kirire on 2025-05-14 14:22:10 +02:00, committed by GitHub
parent 1b00966395
commit 935bbbc711
3 changed files with 44 additions and 21 deletions


@@ -140,6 +140,13 @@ class Mamba2Config(PretrainedConfig):
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise ValueError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
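
For reference, a minimal sketch of how the new guard behaves for callers; the values are illustrative and the top-level import path is assumed:

from transformers import Mamba2Config

# Consistent: hidden_size * expand == num_heads * head_dim (4 * 2 == 4 * 2)
config = Mamba2Config(hidden_size=4, expand=2, num_heads=4, head_dim=2)

# Inconsistent: 4 * 2 != 2 * 2, so the constructor now raises ValueError
try:
    Mamba2Config(hidden_size=4, expand=2, num_heads=2, head_dim=2)
except ValueError as err:
    print(err)  # Inconsistent configuration: hidden_size * expand (8) must equal num_heads * head_dim (4).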


@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -457,13 +456,19 @@ class Mamba2Mixer(nn.Module):
         return out

     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position: Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
@@ -657,11 +662,6 @@ class Mamba2Mixer(nn.Module):
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
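
The masking deleted here is not lost: torch_forward now applies it up front via apply_mask_to_padding_states (see the hunk above). Judging from the removed lines, that helper boils down to the following pattern; this is a sketch based on this diff, not the verbatim library code:

def apply_mask_to_padding_states(hidden_states, attention_mask):
    # Zero out hidden states at pad positions, see https://github.com/state-spaces/mamba/issues/66
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states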
@@ -1018,7 +1018,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,  # for now we need this for generation
+        **kwargs,  # for now we need this for generation and loss_function
     ) -> Union[Tuple, Mamba2CausalLMOutput]:
         r"""
         cache_params (`Mamba2Cache`, *optional*):
@@ -1052,14 +1052,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + mamba2_outputs[1:]
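
The deleted block is the standard shift-by-one causal-LM cross entropy; the inherited self.loss_function computes the same thing and additionally accepts the **kwargs forwarded above. A sketch of the equivalent computation, mirroring the removed lines:

import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size):
    # Move labels to the logits device, shift so that tokens < n predict n, then flatten
    labels = labels.to(logits.device)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1))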


@@ -44,6 +44,29 @@ if is_torch_available():
     from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer


+class Mamba2ConfigTester(ConfigTester):
+    def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
+        _input_dict = self.inputs_dict.copy()
+        _input_dict["hidden_size"] = hidden_size
+        _input_dict["num_heads"] = num_heads
+        _input_dict["expand"] = expand
+        _input_dict["head_dim"] = head_dim
+        return self.config_class(**_input_dict)
+
+    def test_hidden_size_compatibility(self):
+        self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
+        self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
+        self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
+
+    def run_common_tests(self):
+        self.test_hidden_size_compatibility()
+        return super().run_common_tests()
+
+
 class Mamba2ModelTester:
     def __init__(
         self,
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def setUp(self):
         self.model_tester = Mamba2ModelTester(self)
-        self.config_tester = ConfigTester(
+        self.config_tester = Mamba2ConfigTester(
             self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
         )