From 935bbbc7111d71a869259b958c65df6a91f723a2 Mon Sep 17 00:00:00 2001
From: Kirire
Date: Wed, 14 May 2025 14:22:10 +0200
Subject: [PATCH] Add config validation and style tweaks (#37589)

* Add config validation and style tweaks

* Fix style issues

* Fix style issues

* style

* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile
---
 .../models/mamba2/configuration_mamba2.py   |  7 ++++
 .../models/mamba2/modeling_mamba2.py        | 33 ++++++++-----------
 tests/models/mamba2/test_modeling_mamba2.py | 25 +++++++++++++++++++++++++-
 3 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index ae6ea5cface..3b1b1177c0a 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -140,6 +140,13 @@ class Mamba2Config(PretrainedConfig):
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise ValueError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index f097aed9867..28e785b8604 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -457,13 +456,19 @@ class Mamba2Mixer(nn.Module):
         return out
 
     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position: Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype
 
         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
@@ -657,11 +662,6 @@ class Mamba2Mixer(nn.Module):
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
 
 
@@ -1018,7 +1018,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,  # for now we need this for generation
+        **kwargs,  # for now we need this for generation and loss_function
     ) -> Union[Tuple, Mamba2CausalLMOutput]:
         r"""
         cache_params (`Mamba2Cache`, *optional*):
@@ -1052,14 +1052,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + mamba2_outputs[1:]
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 6d9b98ccedb..ee63e825e1f 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -44,6 +44,29 @@ if is_torch_available():
     from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
 
 
+class Mamba2ConfigTester(ConfigTester):
+    def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
+        _input_dict = self.inputs_dict.copy()
+        _input_dict["hidden_size"] = hidden_size
+        _input_dict["num_heads"] = num_heads
+        _input_dict["expand"] = expand
+        _input_dict["head_dim"] = head_dim
+        return self.config_class(**_input_dict)
+
+    def test_hidden_size_compatibility(self):
+        self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
+        self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
+        self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
+
+    def run_common_tests(self):
+        self.test_hidden_size_compatibility()
+        return super().run_common_tests()
+
+
 class Mamba2ModelTester:
     def __init__(
         self,
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
 
     def setUp(self):
         self.model_tester = Mamba2ModelTester(self)
-        self.config_tester = ConfigTester(
+        self.config_tester = Mamba2ConfigTester(
             self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
         )
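
Editor's note (illustration only, not part of the patch): with the validation above in place, an inconsistent Mamba2Config fails at construction time instead of surfacing later as a shape mismatch inside the mixer. A minimal sketch, assuming a transformers install that includes this change and exports Mamba2Config at the top level:

    from transformers import Mamba2Config

    # Consistent: hidden_size * expand == num_heads * head_dim (4 * 2 == 4 * 2), so this builds.
    Mamba2Config(hidden_size=4, expand=2, num_heads=4, head_dim=2)

    # Inconsistent: 4 * 4 != 2 * 2, so the check added in configuration_mamba2.py raises ValueError.
    try:
        Mamba2Config(hidden_size=4, expand=4, num_heads=2, head_dim=2)
    except ValueError as err:
        print(err)

These two configurations are among the cases exercised by Mamba2ConfigTester.test_hidden_size_compatibility above.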