Add Zamba2 (#34517)

* First commit

* Finish model implementation

* First commit

* Finish model implementation

* Register zamba2

* generated modeling and configuration

* generated modeling and configuration

* added hybrid cache

* fix attention_mask in mamba

* dropped unused loras

* fix flash2

* config docstrings

* fix config and fwd pass

* make fixup fixes

* text_modeling_zamba2

* small fixes

* make fixup fixes

* Fix modular model converter

* added inheritances in modular, renamed zamba cache

* modular rebase

* new modular conversion

* fix generated modeling file

* fixed import for Zamba2RMSNormGated

* modular file cleanup

* make fixup and model tests

* dropped inheritance for Zamba2PreTrainedModel

* make fixup and unit tests

* Add inheritance of rope from GemmaRotaryEmbedding

* moved rope to model init

* drop del self.self_attn and del self.feed_forward

* fix tests

* renamed lora -> adapter

* rewrote adapter implementation

* fixed tests

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Dropped adapter in-place sum

* removed rope from attention init

* updated rope

* created get_layers method

* make fixup fix

* make fixup fixes

* make fixup fixes

* update to new attention standard

* update to new attention standard

* make fixup fixes

* minor fixes

* cache_position

* removed cache_position postion_ids use_cache

* remove config from modular

* removed config from modular (2)

* import apply_rotary_pos_emb from llama

* fixed rope_kwargs

* Instantiate cache in Zamba2Model

* fix cache

* fix @slow decorator

* small fix in modular file

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* several minor fixes

* inherit mamba2decoder fwd and drop position_ids in mamba

* removed docstrings from modular

* reinstate zamba2 attention decoder fwd

* use regex for tied keys

* Revert "use regex for tied keys"

This reverts commit 9007a522b1.

* use regex for tied keys

* add cpu to slow forward tests

* dropped config.use_shared_mlp_adapter

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* re-convert from modular

---------

Co-authored-by: root <root@node-2.us-southcentral1-a.compute.internal>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
pglorio 2025-01-27 01:51:23 -08:00 committed by GitHub
parent 14a9bb520e
commit 33cb1f7b61
18 changed files with 4148 additions and 12 deletions


@ -385,6 +385,7 @@ Flax), PyTorch, and/or TensorFlow.
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ |
| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |
<!-- End table-->


@ -0,0 +1,91 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Zamba2
Zamba2 is a large language model (LLM) trained by Zyphra and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) collection for model weights.
This model was contributed by [pglo](https://huggingface.co/pglo).
## Model details
Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models that combine state-space model blocks (specifically [Mamba](https://github.com/state-spaces/mamba)) with transformer blocks, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T to 3T tokens.
<img src="https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f" width="30%" height="40%" />
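The hybrid layout is exposed on the configuration object, so the interleaving of mamba and shared-attention blocks can be inspected directly. A minimal sketch, assuming `transformers>=4.48.0` with Zamba2 available:
```python
from transformers import Zamba2Config

config = Zamba2Config()  # defaults follow Zamba2-2.7B
# Each entry is "mamba" or "hybrid" (a shared transformer block paired with a mamba block).
print(len(config.layers_block_type))             # total number of layers (54 by default)
print(config.layers_block_type.count("hybrid"))  # how many layers include shared attention
print(config.hybrid_layer_ids)                   # indices of the hybrid layers
```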
## Quick start
### Prerequisites
Zamba2 requires you to use `transformers` version 4.48.0 or higher:
```bash
pip install "transformers>=4.48.0"
```
### Inference
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)
input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```
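Batched generation with padded prompts works with the same `model` and `tokenizer`; a hedged sketch mirroring the integration tests (it assumes the checkpoint's tokenizer already defines a padding token):
```python
prompts = [
    "Hey how are you doing on this lovely evening?",
    "When did the Roman empire fall?",
]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```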
## Model card
The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)
## Issues
For issues with model output or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions).
## License
The model weights are open-sourced via an Apache 2.0 license.
## Zamba2Config
[[autodoc]] Zamba2Config
## Zamba2Model
[[autodoc]] Zamba2Model
- forward
## Zamba2ForCausalLM
[[autodoc]] Zamba2ForCausalLM
- forward
## Zamba2ForSequenceClassification
[[autodoc]] Zamba2ForSequenceClassification
- forward


@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
@ -328,6 +329,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
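Either backend is selected the usual way through `attn_implementation` when loading the model; a hedged sketch for Zamba2 (assuming a CUDA device and, for FlashAttention-2, an installed `flash-attn` package):
```python
import torch
from transformers import AutoModelForCausalLM

# Use attn_implementation="sdpa" instead if flash-attn is not installed.
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)
```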
<Tip>


@ -889,6 +889,7 @@ _import_structure = {
"models.yolos": ["YolosConfig"], "models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"], "models.yoso": ["YosoConfig"],
"models.zamba": ["ZambaConfig"], "models.zamba": ["ZambaConfig"],
"models.zamba2": ["Zamba2Config"],
"models.zoedepth": ["ZoeDepthConfig"], "models.zoedepth": ["ZoeDepthConfig"],
"onnx": [], "onnx": [],
"pipelines": [ "pipelines": [
@ -3989,6 +3990,14 @@ else:
"ZambaPreTrainedModel", "ZambaPreTrainedModel",
] ]
) )
_import_structure["models.zamba2"].extend(
[
"Zamba2ForCausalLM",
"Zamba2ForSequenceClassification",
"Zamba2Model",
"Zamba2PreTrainedModel",
]
)
_import_structure["models.zoedepth"].extend( _import_structure["models.zoedepth"].extend(
[ [
"ZoeDepthForDepthEstimation", "ZoeDepthForDepthEstimation",
@ -6004,6 +6013,7 @@ if TYPE_CHECKING:
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zamba import ZambaConfig
from .models.zamba2 import Zamba2Config
from .models.zoedepth import ZoeDepthConfig
# Pipelines
@ -8542,6 +8552,12 @@ if TYPE_CHECKING:
ZambaModel,
ZambaPreTrainedModel,
)
from .models.zamba2 import (
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
Zamba2Model,
Zamba2PreTrainedModel,
)
from .models.zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,


@ -303,5 +303,6 @@ from . import (
yolos,
yoso,
zamba,
zamba2,
zoedepth,
)


@ -335,6 +335,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("yolos", "YolosConfig"), ("yolos", "YolosConfig"),
("yoso", "YosoConfig"), ("yoso", "YosoConfig"),
("zamba", "ZambaConfig"), ("zamba", "ZambaConfig"),
("zamba2", "Zamba2Config"),
("zoedepth", "ZoeDepthConfig"), ("zoedepth", "ZoeDepthConfig"),
] ]
) )
@ -680,6 +681,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("yolos", "YOLOS"), ("yolos", "YOLOS"),
("yoso", "YOSO"), ("yoso", "YOSO"),
("zamba", "Zamba"), ("zamba", "Zamba"),
("zamba2", "Zamba2"),
("zoedepth", "ZoeDepth"), ("zoedepth", "ZoeDepth"),
] ]
) )


@ -303,6 +303,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("yolos", "YolosModel"), ("yolos", "YolosModel"),
("yoso", "YosoModel"), ("yoso", "YosoModel"),
("zamba", "ZambaModel"), ("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
] ]
) )
@ -577,6 +578,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("xlnet", "XLNetLMHeadModel"), ("xlnet", "XLNetLMHeadModel"),
("xmod", "XmodForCausalLM"), ("xmod", "XmodForCausalLM"),
("zamba", "ZambaForCausalLM"), ("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
] ]
) )
@ -1055,6 +1057,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("xmod", "XmodForSequenceClassification"), ("xmod", "XmodForSequenceClassification"),
("yoso", "YosoForSequenceClassification"), ("yoso", "YosoForSequenceClassification"),
("zamba", "ZambaForSequenceClassification"), ("zamba", "ZambaForSequenceClassification"),
("zamba2", "Zamba2ForSequenceClassification"),
] ]
) )


@ -583,6 +583,13 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
(
"zamba2",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
]
)


@ -272,7 +272,6 @@ class ZambaAttention(nn.Module):
layer_idx: int,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[ZambaHybridDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
@ -621,11 +620,9 @@ class ZambaAttentionDecoderLayer(nn.Module):
original_hidden_states: torch.Tensor,
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@ -638,7 +635,6 @@ class ZambaAttentionDecoderLayer(nn.Module):
layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings.
past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@ -655,11 +651,9 @@ class ZambaAttentionDecoderLayer(nn.Module):
hidden_states=hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
# feed-forward (MLP)
@ -688,12 +682,12 @@ class ZambaMambaDecoderLayer(nn.Module):
layer_idx: int = None,
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
transformer_hidden_states: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -756,7 +750,6 @@ class ZambaHybridLayer(nn.Module):
layer_idx: int = None,
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@ -786,7 +779,6 @@ class ZambaHybridLayer(nn.Module):
original_hidden_states=original_hidden_states,
layer_idx=layer_idx,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@ -804,7 +796,6 @@ class ZambaHybridLayer(nn.Module):
hidden_states,
transformer_hidden_states=transformer_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@ -1108,7 +1099,6 @@ class ZambaModel(ZambaPreTrainedModel):
layer_idx,
attention_mask,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
@ -1121,7 +1111,6 @@ class ZambaModel(ZambaPreTrainedModel):
layer_idx=layer_idx,
attention_mask=attention_mask,
causal_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,


@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_zamba2 import *
from .modeling_zamba2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
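Because the module registers itself lazily, the names re-exported above only resolve on first attribute access; a small sanity check, assuming an installed `transformers` that includes this model:
```python
# Importing through the package triggers the lazy load of configuration_zamba2.
from transformers.models.zamba2 import Zamba2Config

print(Zamba2Config().model_type)  # -> "zamba2"
```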


@ -0,0 +1,236 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_zamba2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class Zamba2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Zamba2 model.
[Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Zamba2Model`]
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 54):
Number of hidden layers in the model.
layers_block_type (`list`, *optional*):
List of layer types, which can be either "mamba" or "hybrid".
mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
mamba_ngroups (`int`, *optional*, defaults to 1):
Number of groups for the evolution matrices of mamba 2.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
time_step_limit (`tuple`, *optional*):
Accepted range of time step values.
n_mamba_heads (`int`, *optional*, defaults to 8):
Number of heads for the evolution matrices of mamba 2.
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
chunk_size (`int`, *optional*, defaults to 256):
Size of the chunks that will comprise the sequence.
add_bias_linear (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in various layers
intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
Dimension of the MLP representations.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the MLP.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf).
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_mem_blocks (`int`, *optional*, defaults to 1):
Number of unshared transformer blocks.
use_shared_attention_adapter (`bool`, *optional*, defaults to `False`):
If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers.
adapter_rank (`int`, *optional*, defaults to 128):
Rank of the adapter in the shared MLP and shared attention layers.
use_mem_rope (`bool`, *optional*, defaults to `False`):
If True, includes RoPE in the shared attention layers.
rope_theta (`float`, *optional*, defaults to `10000.0`):
The base period of the RoPE embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
significantly.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
use_long_context (`bool`, *optional*, defaults to `False`):
Activates the context-extended version of Zamba by modifying RoPE.
```python
>>> from transformers import Zamba2Model, Zamba2Config
>>> # Initializing a Zamba2-2.7B style configuration
>>> configuration = Zamba2Config()
>>> # Initializing a model from the Zamba2-2.7B style configuration
>>> model = Zamba2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "zamba2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
max_position_embeddings=4096,
hidden_size=2560,
num_hidden_layers=54,
layers_block_type=None,
mamba_d_state=64,
mamba_d_conv=4,
mamba_expand=2,
mamba_ngroups=1,
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
time_step_limit=None,
n_mamba_heads=8,
use_conv_bias=True,
chunk_size=256,
add_bias_linear=False,
intermediate_size=None,
hidden_act="gelu",
num_attention_heads=32,
num_key_value_heads=None,
attention_dropout=0.0,
num_mem_blocks=1,
use_shared_attention_adapter=False,
adapter_rank=128,
use_mem_rope=False,
rope_theta=10000,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
num_logits_to_keep=1,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
use_long_context=False,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
if intermediate_size is None:
self.intermediate_size = 4 * hidden_size
else:
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_mem_blocks = num_mem_blocks
self.attention_hidden_size = 2 * hidden_size
self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
self.attention_dropout = attention_dropout
self.use_mem_rope = use_mem_rope
self.use_long_context = use_long_context
if use_mem_rope and use_long_context:
a = 8
rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2))
self.rope_theta = rope_theta
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_expand = mamba_expand
self.add_bias_linear = add_bias_linear
self.mamba_ngroups = mamba_ngroups
self.n_mamba_heads = n_mamba_heads
self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
self.use_conv_bias = use_conv_bias
self.chunk_size = chunk_size
self.time_step_limit = time_step_limit
self.use_shared_attention_adapter = use_shared_attention_adapter
self.adapter_rank = adapter_rank
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_floor = time_step_floor
if use_long_context:
self.max_position_embeddings = 16384
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.num_attention_heads = num_attention_heads
self.kv_channels = self.hidden_size // self.num_attention_heads
self.num_query_groups = self.num_attention_heads
# Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
if layers_block_type is None:
self.layers_block_type = (
["mamba"]
+ (["mamba"] * 5 + ["hybrid"]) * 7
+ ["mamba"] * 4
+ ["hybrid"]
+ ["mamba"] * 3
+ ["hybrid"]
+ ["mamba"] * 2
)
else:
self.layers_block_type = layers_block_type
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"]
__all__ = ["Zamba2Config"]
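A few of the derived attributes computed in `__init__` above can be sanity-checked directly; a hedged sketch, assuming the class is importable from `transformers` once this change is installed:
```python
from transformers import Zamba2Config

config = Zamba2Config()  # hidden_size=2560, num_attention_heads=32, num_hidden_layers=54
# intermediate_size falls back to 4 * hidden_size when not provided.
assert config.intermediate_size == 4 * 2560
# The shared attention path runs on a doubled hidden size, so head_dim = 2 * hidden_size / num_heads.
assert config.attention_head_dim == 2 * 2560 // 32  # 160
# The default layers_block_type pattern covers every hidden layer.
assert len(config.layers_block_type) == config.num_hidden_layers == 54
```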

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -1435,6 +1435,7 @@ def set_model_tester_for_less_flaky_test(test_case):
# TODO (if possible): Avoid exceptional cases
exceptional_classes = [
"ZambaModelTester",
"Zamba2ModelTester",
"RwkvModelTester",
"AriaVisionText2TextModelTester",
"GPTNeoModelTester",


@ -10576,6 +10576,34 @@ class ZambaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Zamba2ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Zamba2ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Zamba2Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Zamba2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ZoeDepthForDepthEstimation(metaclass=DummyObject):
_backends = ["torch"]


@ -2279,6 +2279,7 @@ class GenerationTesterMixin:
"mamba", "mamba",
"xlnet", "xlnet",
"zamba", "zamba",
"zamba2",
) )
has_standard_cache = not any( has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache


@ -0,0 +1,666 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Zamba model."""
import math
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, Zamba2Config, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
Zamba2Model,
)
from transformers.models.zamba2.modeling_zamba2 import (
Zamba2HybridDynamicCache,
)
class Zamba2ModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=16,
mamba_d_state=2,
chunk_size=8,
mamba_dt_rank="auto",
num_hidden_layers=2,
num_attention_heads=2,
n_mamba_heads=8,
mamba_ngroups=8,
intermediate_size=4,
hidden_act="gelu",
hidden_mamba_act="silu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
layers_block_type=["mamba", "hybrid"],
num_mem_blocks=1,
use_mem_rope=True,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.mamba_dt_rank = mamba_dt_rank
self.mamba_d_state = mamba_d_state
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_mamba_heads = n_mamba_heads
self.mamba_ngroups = mamba_ngroups
self.chunk_size = chunk_size
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_mamba_act = hidden_mamba_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.layers_block_type = layers_block_type
self.num_mem_blocks = num_mem_blocks
self.use_mem_rope = use_mem_rope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return Zamba2Config(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
mamba_dt_rank=self.mamba_dt_rank,
mamba_d_state=self.mamba_d_state,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
n_mamba_heads=self.n_mamba_heads,
intermediate_size=self.intermediate_size,
chunk_size=self.chunk_size,
hidden_act=self.hidden_act,
mamba_ngroups=self.mamba_ngroups,
hidden_mamba_act=self.hidden_mamba_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=True,
initializer_range=self.initializer_range,
use_mamba_kernels=False,
layers_block_type=self.layers_block_type,
num_mem_blocks=self.num_mem_blocks,
use_mem_rope=self.use_mem_rope,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
return (
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
)
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = Zamba2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(
self,
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
model = Zamba2ForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids, labels=token_labels)
result = model(input_ids)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
config.add_cross_attention = False
model = Zamba2ForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
# Attention: Zamba2 needs the cache to be initialized to return a cache!
past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(
input_ids,
attention_mask=input_mask,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next tokens and extend next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 1), vocab_size=2)
# append to next input_ids and attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
cache_position=torch.arange(
input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
),
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_for_sequence_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = Zamba2ForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
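The hybrid cache used in `create_and_check_decoder_model_past_large_inputs` above has to be constructed explicitly when calling the forward pass directly with `use_cache=True`; a hedged sketch of the same pattern outside the tester, assuming the `Zyphra/Zamba2-1.2B` checkpoint is available in the `transformers` format:
```python
import torch
from transformers import AutoTokenizer, Zamba2ForCausalLM
from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache

model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B", torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
input_ids = tokenizer("Hello", return_tensors="pt")["input_ids"]

# Initialize the hybrid (mamba + attention) cache up front, as the test above does.
cache = Zamba2HybridDynamicCache(model.config, input_ids.shape[0], model.dtype, device=model.device)
out = model(input_ids, past_key_values=cache, use_cache=True)
print(type(out.past_key_values))  # Zamba2HybridDynamicCache
```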
@require_torch
class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
all_model_classes = (
(
Zamba2Model,
Zamba2ForCausalLM,
Zamba2ForSequenceClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (Zamba2ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": Zamba2Model,
"text-classification": Zamba2ForSequenceClassification,
"text-generation": Zamba2ForCausalLM,
"zero-shot": Zamba2ForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
def setUp(self):
self.model_tester = Zamba2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37)
@unittest.skip("position_ids cannot be used to pad due to Mamba2 layers")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Zamba2 has a hybrid cache")
def test_past_key_values_format(self):
r"""
Zamba2's cache shape depends on whether a given layer is mamba or attention.
For mamba layers, the KV cache is empty and has shape [batch_size, 0].
The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
"""
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_initialization(self):
r"""
Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "A_log" in name:
A = torch.arange(1, config.n_mamba_heads + 1, dtype=torch.float32)[None, :]
self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
elif "D" in name:
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
elif "dt_bias" in name:
dt = torch.exp(
torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
+ math.log(config.time_step_min)
).clamp(min=config.time_step_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
if param.requires_grad:
self.assertTrue(param.data.max().item() <= inv_dt[1])
self.assertTrue(param.data.min().item() >= inv_dt[0])
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip(reason="Cumbersome and redundant for Zamba2")
def test_mismatched_shapes_have_properly_initialized_weights(self):
r"""
Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
Mamba block are initialized differently and we tested that in test_initialization
"""
pass
def test_attention_outputs(self):
r"""
Overriding the test_attention_outputs test as the Zamba2 model outputs attention only for its attention layers
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def _get_input_ids_and_config(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
return config, input_ids, input_mask
def test_left_padding_compatibility(self):
r"""
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
caused by left padding, which are discussed in the note below. Using a more permissive tolerance value.
"""
import inspect
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
# First, filter out models that don't support left padding - generative and decoder-only.
# Zamba2 is a decoder-only architecture
decoder_only_classes = self.all_generative_model_classes
# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in decoder_only_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# With left-padding (length 32)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
r"""
Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = inputs_dict[model.main_input_name]
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Zamba2 does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
for _, param in model.named_parameters():
# upcast only layer norms
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)
_ = model(dummy_input)
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
r"""
Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
self.skipTest(reason="Zamba2 flash attention does not support right padding")
@unittest.skip(reason="Zamba2 has its own special cache type")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class Zamba2ModelIntegrationTest(unittest.TestCase):
model = None
tokenizer = None
@classmethod
@slow
def setUpClass(cls):
model_id = "Zyphra/Zamba2-1.2B"
cls.model = Zamba2ForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True, revision="PR"
)
cls.tokenizer = AutoTokenizer.from_pretrained(model_id, revision="PR")
@parameterized.expand([(torch_device,), ("cpu",)])
@slow
def test_simple_generate(self, torch_device):
self.model.to(torch_device)
input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
"input_ids"
].to(torch_device)
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
output_sentence = self.tokenizer.decode(out[0, :])
self.assertEqual(
output_sentence,
"<s> Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
)
with torch.no_grad():
logits = self.model(input_ids=input_ids).logits.to(dtype=torch.float32)
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
-5.9587, 10.5152, 7.0382, -2.8728, -4.8143, -4.8142, -4.8142, -4.8144,
-4.8143, -4.8143, -4.8142, -4.8142, 6.0185, 18.0037, -4.8142, -4.8144,
-4.8143, -4.8142, -4.8143, -4.8143, -4.8143, -4.8143, -4.8142, -4.8143,
-4.8144, -4.8143, -4.8143, -4.8141, -4.8142, -4.8142, -4.8142, -4.8144,
-4.8143, -4.8143, -4.8143, -4.8142, -4.8144, -4.8144, -4.8142, -4.8142
]
, dtype=torch.float32) # fmt: skip
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
@parameterized.expand([(torch_device,), ("cpu",)])
@slow
def test_simple_batched_generate_with_padding(self, torch_device):
self.model.to(torch_device)
inputs = self.tokenizer(
["Hey how are you doing on this lovely evening?", "When did the Roman empire "],
padding=True,
return_tensors="pt",
).to(torch_device)
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
output_sentences = self.tokenizer.batch_decode(out)
self.assertEqual(
output_sentences[0],
"<s> Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
)
self.assertEqual(
output_sentences[1],
"[PAD][PAD][PAD][PAD]<s> When did the Roman empire 1st fall?\nThe Roman Empire fell in",
)
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(
dtype=torch.float32
)
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
-5.9611, 10.5208, 7.0411, -2.8743, -4.8167, -4.8167, -4.8167, -4.8168,
-4.8167, -4.8167, -4.8167, -4.8166, 6.0218, 18.0062, -4.8167, -4.8168,
-4.8167, -4.8167, -4.8167, -4.8168, -4.8168, -4.8168, -4.8167, -4.8167,
-4.8168, -4.8167, -4.8167, -4.8165, -4.8167, -4.8167, -4.8167, -4.8169,
-4.8168, -4.8168, -4.8168, -4.8166, -4.8169, -4.8168, -4.8167, -4.8167
]
, dtype=torch.float32) # fmt: skip
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
-6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
-6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
-6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
-6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105 ]
, dtype=torch.float32) # fmt: skip
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(
logits[1, -1, :40].cpu(),
EXPECTED_LOGITS_NO_GRAD_1,
rtol=1e-3,
atol=6e-3 if torch_device == "cpu" else 1e-3,
)