mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add config validation and style tweaks (#37589)
* Add config validation and style tweaks
* Fix style issues
* Fix style issues
* style
* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
This commit is contained in:
parent 1b00966395
commit 935bbbc711
@@ -140,6 +140,13 @@ class Mamba2Config(PretrainedConfig):
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise ValueError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
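
For illustration only (not part of this commit), a minimal sketch of how the new check behaves, reusing dimension combinations from the test added further down:

from transformers import Mamba2Config

# Consistent: hidden_size * expand (4 * 2) equals num_heads * head_dim (4 * 2)
Mamba2Config(hidden_size=4, num_heads=4, expand=2, head_dim=2)

# Inconsistent: 2 * 2 != 4 * 4, so __init__ now raises before any attribute is set
try:
    Mamba2Config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
except ValueError as err:
    print(err)  # Inconsistent configuration: hidden_size * expand (4) must equal num_heads * head_dim (16).
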
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -457,13 +456,19 @@ class Mamba2Mixer(nn.Module):
         return out
 
     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position:Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype
 
         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
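
As a side note on the context lines above: the d_mlp arithmetic works out to zero for a standard Mamba2 projection, so the first two slots of the split are empty. The sketch below is illustrative only; the concrete numbers are made up, and the conv_dim / in_proj width formulas are assumptions based on Mamba2Mixer.__init__, which is not part of this diff:

# Assumed relationships (not shown in this diff):
#   conv_dim      = intermediate_size + 2 * n_groups * ssm_state_size
#   in_proj width = 2 * intermediate_size + 2 * n_groups * ssm_state_size + num_heads
intermediate_size, n_groups, ssm_state_size, num_heads = 8, 1, 4, 4
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
proj_width = 2 * intermediate_size + 2 * n_groups * ssm_state_size + num_heads

d_mlp = (proj_width - 2 * intermediate_size - 2 * n_groups * ssm_state_size - num_heads) // 2
assert d_mlp == 0
assert [d_mlp, d_mlp, intermediate_size, conv_dim, num_heads] == [0, 0, 8, 16, 4]
assert sum([d_mlp, d_mlp, intermediate_size, conv_dim, num_heads]) == proj_width
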
@@ -657,11 +662,6 @@ class Mamba2Mixer(nn.Module):
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
 
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
-
 
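
The deleted block zeroed out hidden states at padding positions; the same masking is already applied at the top of torch_forward via apply_mask_to_padding_states (visible in the previous hunk), which is why it can be dropped here. A standalone sketch of the operation, with shapes invented for illustration:

import torch

hidden_states = torch.randn(2, 5, 8)             # (batch, seq_len, hidden_size)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # 0 marks padding positions
                               [1, 1, 1, 1, 1]])

# Broadcasting the mask over the hidden dimension zeroes the pad tokens,
# cf. https://github.com/state-spaces/mamba/issues/66
masked = (hidden_states * attention_mask[:, :, None]).to(hidden_states.dtype)
assert torch.all(masked[0, 3:] == 0)
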
@@ -1018,7 +1018,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,  # for now we need this for generation
+        **kwargs,  # for now we need this for generation and loss_function
     ) -> Union[Tuple, Mamba2CausalLMOutput]:
         r"""
         cache_params (`Mamba2Cache`, *optional*):
@@ -1052,14 +1052,7 @@ class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + mamba2_outputs[1:]
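
For reference, a self-contained sketch of the next-token loss the deleted lines computed by hand; the shared self.loss_function helper now encapsulates this logic (and receives **kwargs). Shapes and values below are invented for the example:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 6, 10
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Shift so that tokens < n predict n, then flatten and apply cross entropy,
# exactly as the removed lines did.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
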
@@ -44,6 +44,29 @@ if is_torch_available():
     from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
 
 
+class Mamba2ConfigTester(ConfigTester):
+    def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
+        _input_dict = self.inputs_dict.copy()
+        _input_dict["hidden_size"] = hidden_size
+        _input_dict["num_heads"] = num_heads
+        _input_dict["expand"] = expand
+        _input_dict["head_dim"] = head_dim
+        return self.config_class(**_input_dict)
+
+    def test_hidden_size_compatibility(self):
+        self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
+        self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
+        self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
+
+    def run_common_tests(self):
+        self.test_hidden_size_compatibility()
+        return super().run_common_tests()
+
+
 class Mamba2ModelTester:
     def __init__(
        self,
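
The dimension combinations exercised above line up with the constraint added to the config; a quick arithmetic sanity check, for illustration only:

# (hidden_size, expand, num_heads, head_dim) tuples from test_hidden_size_compatibility
valid = [(2, 2, 2, 2), (4, 2, 4, 2), (2, 4, 4, 2)]
invalid = [(2, 2, 4, 4), (4, 4, 2, 2)]

for hidden_size, expand, num_heads, head_dim in valid:
    assert hidden_size * expand == num_heads * head_dim
for hidden_size, expand, num_heads, head_dim in invalid:
    assert hidden_size * expand != num_heads * head_dim
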
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
 
     def setUp(self):
         self.model_tester = Mamba2ModelTester(self)
-        self.config_tester = ConfigTester(
+        self.config_tester = Mamba2ConfigTester(
             self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
         )
 