Add the Bamba Model (#34982)
* initial commit for PR Co-authored-by: Gabe Goodhart <gabe.l.hart@gmail.com>
* rename dynamic cache Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* add more unit tests Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* add integration test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* add integration test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* Add modular bamba file
* Remove trainer changes from unrelated PR
* Modify modular and cofig to get model running
* Fix some CI errors and beam search
* Fix a plethora of bugs from CI/docs/etc
* Add bamba to models with special caches
* Updat to newer mamba PR for mamba sublayer
* fix test_left_padding_compatibility Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* fix style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* fix remaining tests Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* missed this test Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* ran make style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* move slow tag to integration obj Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* make style Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* address comments Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* fix modular Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* left out one part of modular Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* change model Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* Make Rotary modular as well
* Update bamba.md Added overview, update Model inference card and added config
* Update bamba.md
* Update bamba.md
* Update bamba.md Minor fixes
* Add docs for config and model back Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
* Add warning when using fast kernels
* replaced generate example Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
* Address comments from PR Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
* Propagate attention fixes Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
* Fix attention interfaces to the new API Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
* Fix API for decoder layer Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
* Remove extra weights Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
---------
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
Co-authored-by: Gabe Goodhart <gabe.l.hart@gmail.com>
Co-authored-by: Antoni Viros i Martin <aviros@ibm.com>
Co-authored-by: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com>
Co-authored-by: Antoni Viros <ani300@gmail.com>
This commit is contained in:
parent 9a94dfe123
commit 9613933b02
@ -322,6 +322,8 @@
  sections:
  - local: model_doc/albert
    title: ALBERT
  - local: model_doc/bamba
    title: Bamba
  - local: model_doc/bart
    title: BART
  - local: model_doc/barthez
@ -66,6 +66,7 @@ Flax), PyTorch, and/or TensorFlow.
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
| [Bamba](model_doc/bamba) | ✅ | ❌ | ❌ |
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
| [BART](model_doc/bart) | ✅ | ✅ | ✅ |
| [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ |
docs/source/en/model_doc/bamba.md (new file, 64 lines)
@ -0,0 +1,64 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Bamba

## Overview

Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

Check out all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).

## BambaConfig

| Model | Params     | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
|-------|------------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
| Bamba | 9B (9.78B) | 32       | 4096        | 32              | Yes | 8        | 4096           | True            |

[[autodoc]] BambaConfig
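
The snippet below is a minimal, illustrative sketch of building a small Bamba model from a custom `BambaConfig`; the sizes are examples, not the released 9B hyper-parameters.

```python
from transformers import BambaConfig, BambaForCausalLM

# Illustrative, down-scaled configuration (not the released 9B hyper-parameters).
config = BambaConfig(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=2,
    attn_layer_indices=[3, 7],  # these layers use full attention, the rest use the Mamba-2 mixer
)

model = BambaForCausalLM(config)  # randomly initialized
print(config.layers_block_type)   # ['mamba', 'mamba', 'mamba', 'attention', ...]
```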

<!---
## Usage Tips

Tips:

- The architecture is based on Mamba-2 models.

## BambaModel

[[autodoc]] BambaModel
    - forward
-->

## BambaForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
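
For batched prompts, the sketch below (with illustrative prompts) pads on the left, mirroring how the integration tests configure the tokenizer; Bamba is decoder-only, so left padding keeps pad tokens away from the generated continuation.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

# Decoder-only models should be padded on the left for generation.
tokenizer.pad_token_id = model.config.pad_token_id
tokenizer.padding_side = "left"

messages = ["Mamba is a snake with following properties ", "The capital of France is"]
inputs = tokenizer(messages, padding=True, return_tensors="pt", return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(response, skip_special_tokens=True))
```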

[[autodoc]] BambaForCausalLM
    - forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
FlashAttention-2 is currently supported for the following architectures:
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
@ -220,6 +221,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
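
Assuming the checkpoint documented above, a sketch of opting into one of these attention backends at load time via the standard `attn_implementation` argument:

```python
import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 requires the flash-attn package and a half-precision dtype; "sdpa" needs neither.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-fms/Bamba-9B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa"
)
```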
|
@ -193,6 +193,7 @@ _import_structure = {
|
||||
"AutoTokenizer",
|
||||
],
|
||||
"models.autoformer": ["AutoformerConfig"],
|
||||
"models.bamba": ["BambaConfig"],
|
||||
"models.bark": [
|
||||
"BarkCoarseConfig",
|
||||
"BarkConfig",
|
||||
@ -1540,6 +1541,13 @@ else:
|
||||
"AutoformerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.bamba"].extend(
|
||||
[
|
||||
"BambaForCausalLM",
|
||||
"BambaModel",
|
||||
"BambaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.bark"].extend(
|
||||
[
|
||||
"BarkCausalModel",
|
||||
@ -5104,6 +5112,7 @@ if TYPE_CHECKING:
|
||||
from .models.autoformer import (
|
||||
AutoformerConfig,
|
||||
)
|
||||
from .models.bamba import BambaConfig
|
||||
from .models.bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
@ -6493,6 +6502,7 @@ if TYPE_CHECKING:
|
||||
AutoformerModel,
|
||||
AutoformerPreTrainedModel,
|
||||
)
|
||||
from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel
|
||||
from .models.bark import (
|
||||
BarkCausalModel,
|
||||
BarkCoarseModel,
|
||||
|
@ -1693,6 +1693,7 @@ class GenerationMixin:
|
||||
self._supports_cache_class
|
||||
and "jamba" not in self.__class__.__name__.lower()
|
||||
and "zamba" not in self.__class__.__name__.lower()
|
||||
and "bamba" not in self.__class__.__name__.lower()
|
||||
)
|
||||
|
||||
def _prepare_cache_for_generation(
|
||||
|
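
For context, Bamba (like Jamba and Zamba) is excluded above because it manages its own `HybridMambaAttentionDynamicCache` rather than the default cache class. A hedged sketch with a small, illustrative config, mirroring the new tests later in this diff:

```python
import torch
from transformers import BambaConfig, BambaForCausalLM
from transformers.models.bamba.modeling_bamba import HybridMambaAttentionDynamicCache

# Small, illustrative config (not the released checkpoint).
config = BambaConfig(
    vocab_size=1024,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = BambaForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
# The hybrid cache has to be created up front for the model to return past_key_values.
past_key_values = HybridMambaAttentionDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
print(type(outputs.past_key_values).__name__)  # HybridMambaAttentionDynamicCache
```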
@ -20,6 +20,7 @@ from . import (
|
||||
audio_spectrogram_transformer,
|
||||
auto,
|
||||
autoformer,
|
||||
bamba,
|
||||
bark,
|
||||
bart,
|
||||
barthez,
|
||||
|
@ -39,6 +39,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("aria_text", "AriaTextConfig"),
|
||||
("audio-spectrogram-transformer", "ASTConfig"),
|
||||
("autoformer", "AutoformerConfig"),
|
||||
("bamba", "BambaConfig"),
|
||||
("bark", "BarkConfig"),
|
||||
("bart", "BartConfig"),
|
||||
("beit", "BeitConfig"),
|
||||
@ -337,6 +338,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("aria_text", "AriaText"),
|
||||
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
|
||||
("autoformer", "Autoformer"),
|
||||
("bamba", "Bamba"),
|
||||
("bark", "Bark"),
|
||||
("bart", "BART"),
|
||||
("barthez", "BARThez"),
|
||||
|
@ -39,6 +39,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("aria_text", "AriaTextModel"),
|
||||
("audio-spectrogram-transformer", "ASTModel"),
|
||||
("autoformer", "AutoformerModel"),
|
||||
("bamba", "BambaModel"),
|
||||
("bark", "BarkModel"),
|
||||
("bart", "BartModel"),
|
||||
("beit", "BeitModel"),
|
||||
@ -471,6 +472,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
("aria_text", "AriaTextForCausalLM"),
|
||||
("bamba", "BambaForCausalLM"),
|
||||
("bart", "BartForCausalLM"),
|
||||
("bert", "BertLMHeadModel"),
|
||||
("bert-generation", "BertGenerationDecoder"),
|
||||
|
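
With these mappings registered, the auto classes resolve the `"bamba"` model type to the new Bamba classes; a small sketch with an illustrative, down-scaled config:

```python
from transformers import AutoModelForCausalLM, BambaConfig, BambaForCausalLM

# Down-scaled, illustrative config; model_type "bamba" now routes to the Bamba classes.
config = BambaConfig(
    vocab_size=1024,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = AutoModelForCausalLM.from_config(config)
assert isinstance(model, BambaForCausalLM)
```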
src/transformers/models/bamba/__init__.py (new file, 28 lines)
@ -0,0 +1,28 @@
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_bamba import *
|
||||
from .modeling_bamba import *
|
||||
from .processing_bamba import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
src/transformers/models/bamba/configuration_bamba.py (new file, 206 lines)
@ -0,0 +1,206 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bamba model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BambaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a
|
||||
BambaModel according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf).
|
||||
|
||||
The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
|
||||
The checkpoints are jointly trained by IBM, Princeton, and UIUC.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128000):
|
||||
Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`BambaModel`]
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
||||
model has an output word embedding layer.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
||||
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
|
||||
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
|
||||
sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
|
||||
significantly.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the "end-of-sequence" token.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 262144):
|
||||
Max cached sequence length for the model
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
attn_layer_indices (`list`, *optional*):
|
||||
Specifies the layer indices that will have full attention. Values must be smaller than `num_hidden_layers`.
|
||||
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||
The number of mamba heads used in the v2 implementation.
|
||||
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||
Head embedding dimension size.
|
||||
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||
The number of the mamba groups used in the v2 implementation.
|
||||
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||
The dimension of the mamba state space latents
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
The chunks in which to break the sequence when doing prefill/training
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
|
||||
|
||||
"""
|
||||
|
||||
model_type = "bamba"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128000,
|
||||
tie_word_embeddings=False,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
max_position_embeddings=262144,
|
||||
attention_dropout=0.0,
|
||||
attn_layer_indices=None,
|
||||
mamba_n_heads=128,
|
||||
mamba_d_head="auto",
|
||||
mamba_n_groups=1,
|
||||
mamba_d_state=256,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=256,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = False
|
||||
self.mlp_bias = False
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.num_logits_to_keep = num_logits_to_keep
|
||||
|
||||
self.attn_layer_indices = attn_layer_indices
|
||||
self.rope_theta = 10000.0
|
||||
self.rope_scaling = None
|
||||
self.partial_rotary_factor = 0.5
|
||||
|
||||
mamba_intermediate = mamba_expand * hidden_size
|
||||
|
||||
if mamba_intermediate % mamba_n_heads != 0:
|
||||
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||
|
||||
# for the mamba_v2, must satisfy the following
|
||||
if mamba_d_head == "auto":
|
||||
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||
|
||||
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
|
||||
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_head = mamba_d_head
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.mamba_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return [
|
||||
"attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["BambaConfig"]
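
As a quick illustration of the `mamba_d_head="auto"` bookkeeping above (using the default hyper-parameters):

```python
from transformers import BambaConfig

# With mamba_d_head="auto", the head dimension is derived from the expanded hidden size.
config = BambaConfig()  # defaults: hidden_size=4096, mamba_expand=2, mamba_n_heads=128
assert config.mamba_d_head == (config.mamba_expand * config.hidden_size) // config.mamba_n_heads  # 64
```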
|
src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py (new file, 273 lines)
@ -0,0 +1,273 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced to mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: Dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only causal attention is supported."
assert not attn_cfg["qkv_proj_bias"], "Only qkv projections without bias are supported."
assert not attn_cfg["out_proj_bias"], "Only output projections without bias are supported."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Pad the vocab size to a multiple of `pad_vocab_size_multiple` (usually 16, but 32 is also very common in different models)
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def save_single_safetensor(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: str = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable from the mamba_ssm config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
const="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba_ssm_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
args.tokenizer_model_path,
|
||||
)
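
For reference, a sketch of driving the converter from Python rather than the CLI; all paths below are placeholders:

```python
from transformers.models.bamba.convert_mamba_ssm_checkpoint import (
    convert_mamba_ssm_checkpoint_file_to_huggingface_model_file,
)

# Placeholder paths; the checkpoint directory must contain config.json and pytorch_model.bin.
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
    mamba_ssm_checkpoint_path="/path/to/mamba_ssm_checkpoint",
    precision="bf16",
    output_dir="/path/to/bamba_hf_checkpoint",
    tokenizer_path="/path/to/tokenizer",
)
```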
|
src/transformers/models/bamba/modeling_bamba.py (new file, 1615 lines)
File diff suppressed because it is too large.

src/transformers/models/bamba/modular_bamba.py (new file, 1303 lines)
File diff suppressed because it is too large.
@ -1167,6 +1167,27 @@ class AutoformerPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BambaForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BambaModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BambaPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BarkCausalModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -2313,6 +2313,7 @@ class GenerationTesterMixin:
        # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match the
        # standard cache format (e.g. gptbigcode)
        models_without_standard_cache = (
            "bamba",
            "ctrl",
            "fsmt",
            "gptbigcode",
tests/models/bamba/__init__.py (new file, 0 lines)

tests/models/bamba/test_modeling_bamba.py (new file, 603 lines)
@ -0,0 +1,603 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Bamba model."""
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BambaForCausalLM,
|
||||
BambaModel,
|
||||
)
|
||||
from transformers.models.bamba.modeling_bamba import (
|
||||
HybridMambaAttentionDynamicCache,
|
||||
)
|
||||
|
||||
|
||||
class BambaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
intermediate_size=64,
|
||||
hidden_act="silu",
|
||||
attention_dropout=0.0,
|
||||
attn_layer_indices=None,
|
||||
attn_rotary_emb=8,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
pad_token_id=0,
|
||||
mamba_n_groups=1,
|
||||
mamba_n_heads=16,
|
||||
mamba_d_state=16,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=16,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_layer_indices = attn_layer_indices
|
||||
self.attn_rotary_emb = attn_rotary_emb
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask, token_labels
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
# Fix for SDPA tests, force at least 4 layers
|
||||
if self.num_hidden_layers < 4:
|
||||
self.num_hidden_layers = 4
|
||||
if self.attn_layer_indices is None:
|
||||
d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
|
||||
if len(d) == 0:
|
||||
raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
|
||||
d = d[-1] # get the largest divisor
|
||||
self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
|
||||
|
||||
return BambaConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
attention_dropout=self.attention_dropout,
|
||||
attn_layer_indices=self.attn_layer_indices,
|
||||
attn_rotary_emb=self.attn_rotary_emb,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
mamba_n_groups=self.mamba_n_groups,
|
||||
mamba_n_heads=self.mamba_n_heads,
|
||||
mamba_d_state=self.mamba_d_state,
|
||||
mamba_d_conv=self.mamba_d_conv,
|
||||
mamba_expand=self.mamba_expand,
|
||||
mamba_chunk_size=self.mamba_chunk_size,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = BambaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = BambaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids, labels=token_labels)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
# config.is_decoder = True
|
||||
# config.add_cross_attention = True
|
||||
model = BambaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
# Attention: Bamba needs the cache to be initialized to return a cache!
|
||||
past_key_values = HybridMambaAttentionDynamicCache(
|
||||
config, input_ids.shape[0], model.dtype, device=model.device
|
||||
)
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next tokens and extend next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and attention mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
cache_position=torch.arange(
|
||||
input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
|
||||
),
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
|
||||
@require_torch
|
||||
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
BambaModel,
|
||||
BambaForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BambaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": BambaModel,
|
||||
"text-generation": BambaForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BambaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_initialization(self):
|
||||
r"""
|
||||
Overriding the test_initialization test as the A_log and D params of the Bamba mixer are initialized differently
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if "A_log" in name:
|
||||
A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)[None, :]
|
||||
self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
|
||||
elif "D" in name:
|
||||
D = torch.ones(config.mamba_n_heads, dtype=torch.float32)
|
||||
self.assertTrue(torch.allclose(param.data, D, atol=1e-5, rtol=1e-5))
|
||||
else:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||
r"""
|
||||
Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
|
||||
Bamba mixer are initialized differently and we tested that in test_initialization
|
||||
"""
|
||||
self.skipTest(reason="Cumbersome and redundant for Bamba")
|
||||
|
||||
def test_attention_outputs(self):
|
||||
r"""
|
||||
Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), expected_num_attentions)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Bamba has its own special cache type")
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
def test_batching_equivalence(self):
|
||||
# need to disable the tril input mask
|
||||
orig = self.model_tester.use_input_mask
|
||||
self.model_tester.use_input_mask = False
|
||||
super().test_batching_equivalence()
|
||||
self.model_tester.use_input_mask = orig
|
||||
|
||||
# essentially the same test as in test_utils, just with an adjusted rtol for this model
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# NOTE: left-padding results in small numerical differences. This is expected.
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
|
||||
|
||||
# First, filter out models that don't support left padding
|
||||
# - The model must have generative capabilities
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest(reason="No generative architecture available for this model.")
|
||||
|
||||
# - The model must support padding
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="This model doesn't support padding.")
|
||||
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _ = self.prepare_config_and_inputs_for_generate()
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
else:
|
||||
decoder_only_classes.append(model_class)
|
||||
if len(decoder_only_classes) == 0:
|
||||
self.skipTest(reason="No decoder-only architecture available for this model.")
|
||||
|
||||
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
|
||||
# added support for it yet. We skip these models for now.
|
||||
has_encoder_attributes = any(
|
||||
attr_name
|
||||
for attr_name in config.to_dict().keys()
|
||||
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
|
||||
)
|
||||
if has_encoder_attributes:
|
||||
self.skipTest(
|
||||
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
|
||||
)
|
||||
|
||||
# Then, test left-padding
|
||||
def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
||||
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
if "position_ids" in signature:
|
||||
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
model_kwargs["position_ids"] = position_ids
|
||||
if "cache_position" in signature:
|
||||
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# - for left padding we absolutely need to use an all ones
|
||||
# attention mask, so we do not use the one in inputs_dict
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
# no cache as some models require special cache classes to be init outside forward
|
||||
model.generation_config.use_cache = False
|
||||
|
||||
# Without padding
|
||||
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
|
||||
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# With left-padding (length 32)
|
||||
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||
pad_token_id = (
|
||||
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||
)
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-1)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class BambaModelIntegrationTest(unittest.TestCase):
|
||||
model = None
|
||||
tokenizer = None
|
||||
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model_id = "ibm-fms/Bamba-9B"
|
||||
cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# feels a bit forced to have to do this for the generation test
|
||||
cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
|
||||
cls.tokenizer.padding_side = "left"
|
||||
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
def test_simple_generate(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
# 7: "",
|
||||
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
||||
# 9: """,
|
||||
}
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
|
||||
"input_ids"
|
||||
].to(torch_device)
|
||||
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||
output_sentence = self.tokenizer.decode(out[0, :])
|
||||
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
149., 142., 146., 142., 143., 144., 142., 145.,
|
||||
142., 146., 144., 146., 147., 147., 148., 145.,
|
||||
147., 145., 145., 145., 145., 144., 144., 144.,
|
||||
144., 145., 147., 146., 144., 144., 148., 147.,
|
||||
148., 147., 147., 147., 146., 146., 148., 148.
|
||||
], dtype=torch.bfloat16) # fmt: skip
|
||||
|
||||
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1)
|
||||
|
||||
def test_simple_batched_generate_with_padding(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [],
|
||||
8: [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
"!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
|
||||
],
|
||||
9: [],
|
||||
}
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
inputs = self.tokenizer(
|
||||
["Hey how are you doing on this lovely evening?", "I am late! I need to"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
output_sentences = self.tokenizer.batch_decode(out)
|
||||
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
|
||||
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
||||
[
|
||||
149., 142., 146., 142., 143., 144., 142., 145.,
|
||||
142., 146., 144., 146., 147., 147., 148., 145.,
|
||||
147., 145., 145., 145., 145., 144., 144., 144.,
|
||||
144., 145., 147., 146., 144., 144., 148., 147.,
|
||||
148., 147., 147., 147., 146., 146., 148., 148.
|
||||
], dtype=torch.bfloat16) # fmt: skip
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||
[
|
||||
182., 178., 177., 174., 176., 176., 178., 178.,
|
||||
177., 179., 176., 183., 180., 182., 179., 174.,
|
||||
178., 176., 176., 175., 175., 175., 174., 173.,
|
||||
174., 182., 180., 176., 177., 177., 180., 176.,
|
||||
178., 177., 177., 175., 176., 177., 175., 177.
|
||||
], dtype=torch.bfloat16) # fmt: skip
|
||||
|
||||
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1)
|
||||
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1)
|
@ -34,6 +34,9 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
    # 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
    # periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
    "BambaConfig": [
        "attn_layer_indices",
    ],
    "JambaConfig": [
        "max_position_embeddings",
        "attn_layer_offset",