Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-04 13:20:12 +06:00
Add config validation and style tweaks (#37589)
* Add config validation and style tweaks
* Fix style issues
* Fix style issues
* style
* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
This commit is contained in:
parent 1b00966395
commit 935bbbc711
src/transformers/models/mamba2/configuration_mamba2.py

@@ -140,6 +140,13 @@ class Mamba2Config(PretrainedConfig):
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise ValueError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
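With this change, an inconsistent Mamba2Config fails at construction time instead of surfacing later as a shape mismatch. A minimal sketch of the behaviour (the concrete values below are picked for illustration and are not taken from this diff):

from transformers import Mamba2Config

# Consistent: hidden_size * expand == num_heads * head_dim (4 * 2 == 4 * 2)
config = Mamba2Config(hidden_size=4, expand=2, num_heads=4, head_dim=2)

# Inconsistent: 4 * 2 != 3 * 2, so the new check raises ValueError
try:
    Mamba2Config(hidden_size=4, expand=2, num_heads=3, head_dim=2)
except ValueError as err:
    print(err)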
src/transformers/models/mamba2/modeling_mamba2.py

@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -457,13 +456,19 @@ class Mamba2Mixer(nn.Module):
         return out

     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position:Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
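The renamed signature does not change the split arithmetic: d_mlp is derived from the in_proj output width. A rough numeric sketch, assuming the in_proj width used by this mixer is intermediate_size + conv_dim + num_heads (so d_mlp works out to 0 and the first two chunks are empty), with toy sizes that also satisfy the new config check:

import torch

# Toy sizes (illustrative, not from this diff): hidden_size=2, expand=2 -> intermediate_size=4,
# n_groups=1, ssm_state_size=2, num_heads=2, head_dim=2, conv_dim = 4 + 2*1*2 = 8,
# assumed in_proj width = 4 + 8 + 2 = 14.
intermediate_size, n_groups, ssm_state_size, num_heads, conv_dim = 4, 1, 2, 2, 8
projected_states = torch.randn(1, 3, 14)  # (batch, seq_len, in_proj width)

d_mlp = (projected_states.shape[-1] - 2 * intermediate_size - 2 * n_groups * ssm_state_size - num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
    [d_mlp, d_mlp, intermediate_size, conv_dim, num_heads], dim=-1
)
print(d_mlp, gate.shape, hidden_states_B_C.shape, dt.shape)
# 0 torch.Size([1, 3, 4]) torch.Size([1, 3, 8]) torch.Size([1, 3, 2])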
@@ -657,11 +662,6 @@ class Mamba2Mixer(nn.Module):
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)

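The lines removed here duplicated masking that torch_forward already performs through apply_mask_to_padding_states. A sketch of what that helper does, reconstructed from the removed lines rather than copied from the library source:

from typing import Optional

import torch

def apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states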
@@ -1018,7 +1018,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,  # for now we need this for generation
+        **kwargs,  # for now we need this for generation and loss_function
     ) -> Union[Tuple, Mamba2CausalLMOutput]:
         r"""
         cache_params (`Mamba2Cache`, *optional*):
@@ -1052,14 +1052,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + mamba2_outputs[1:]
tests/models/mamba2/test_modeling_mamba2.py

@@ -44,6 +44,29 @@ if is_torch_available():
     from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer


+class Mamba2ConfigTester(ConfigTester):
+    def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
+        _input_dict = self.inputs_dict.copy()
+        _input_dict["hidden_size"] = hidden_size
+        _input_dict["num_heads"] = num_heads
+        _input_dict["expand"] = expand
+        _input_dict["head_dim"] = head_dim
+        return self.config_class(**_input_dict)
+
+    def test_hidden_size_compatibility(self):
+        self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
+        self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
+        self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
+
+    def run_common_tests(self):
+        self.test_hidden_size_compatibility()
+        return super().run_common_tests()
+
+
 class Mamba2ModelTester:
     def __init__(
         self,
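The pass/fail cases above follow directly from the new constraint; written out with the same numbers (plain arithmetic only):

# Accepted: hidden_size * expand == num_heads * head_dim
assert 2 * 2 == 2 * 2   # hidden_size=2, expand=2, num_heads=2, head_dim=2
assert 4 * 2 == 4 * 2   # hidden_size=4, expand=2, num_heads=4, head_dim=2
assert 2 * 4 == 4 * 2   # hidden_size=2, expand=4, num_heads=4, head_dim=2

# Rejected: the products differ, so Mamba2Config raises ValueError
assert 2 * 2 != 4 * 4   # hidden_size=2, expand=2, num_heads=4, head_dim=4
assert 4 * 4 != 2 * 2   # hidden_size=4, expand=4, num_heads=2, head_dim=2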
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix

     def setUp(self):
         self.model_tester = Mamba2ModelTester(self)
-        self.config_tester = ConfigTester(
+        self.config_tester = Mamba2ConfigTester(
             self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
         )