Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 13:20:12 +06:00
Add Zamba2 (#34517)
* First commit
* Finish model implementation
* First commit
* Finish model implementation
* Register zamba2
* generated modeling and configuration
* generated modeling and configuration
* added hybrid cache
* fix attention_mask in mamba
* dropped unused loras
* fix flash2
* config docstrings
* fix config and fwd pass
* make fixup fixes
* text_modeling_zamba2
* small fixes
* make fixup fixes
* Fix modular model converter
* added inheritances in modular, renamed zamba cache
* modular rebase
* new modular conversion
* fix generated modeling file
* fixed import for Zamba2RMSNormGated
* modular file cleanup
* make fixup and model tests
* dropped inheritance for Zamba2PreTrainedModel
* make fixup and unit tests
* Add inheritance of rope from GemmaRotaryEmbedding
* moved rope to model init
* drop del self.self_attn and del self.feed_forward
* fix tests
* renamed lora -> adapter
* rewrote adapter implementation
* fixed tests
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Dropped adapter in-place sum
* removed rope from attention init
* updated rope
* created get_layers method
* make fixup fix
* make fixup fixes
* make fixup fixes
* update to new attention standard
* update to new attention standard
* make fixup fixes
* minor fixes
* cache_position
* removed cache_position postion_ids use_cache
* remove config from modular
* removed config from modular (2)
* import apply_rotary_pos_emb from llama
* fixed rope_kwargs
* Instantiate cache in Zamba2Model
* fix cache
* fix @slow decorator
* small fix in modular file
* Update docs/source/en/model_doc/zamba2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* several minor fixes
* inherit mamba2decoder fwd and drop position_ids in mamba
* removed docstrings from modular
* reinstate zamba2 attention decoder fwd
* use regex for tied keys
* Revert "use regex for tied keys"
This reverts commit 9007a522b1.
* use regex for tied keys
* add cpu to slow forward tests
* dropped config.use_shared_mlp_adapter
* Update docs/source/en/model_doc/zamba2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* re-convert from modular
---------
Co-authored-by: root <root@node-2.us-southcentral1-a.compute.internal>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in: parent 14a9bb520e, commit 33cb1f7b61
@@ -385,6 +385,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
 | [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
 | [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ |
+| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ |
 | [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |
 
 <!-- End table-->
docs/source/en/model_doc/zamba2.md (new file, 91 lines)
@@ -0,0 +1,91 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Zamba2

Zamba2 is a large language model (LLM) trained by Zyphra and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights.

This model was contributed by [pglo](https://huggingface.co/pglo).


## Model details

Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space models (specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer blocks, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T and 3T tokens, respectively.

<img src=https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f width=30% height=40% />

## Quick start

### Prerequisites

Zamba2 requires `transformers` version 4.48.0 or higher:

```bash
pip install transformers>=4.48.0
```

## Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```
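For batched prompts (not covered in the doc snippet above), a hedged sketch: the Llama-style tokenizer reused by Zamba2 defines no pad token, so one is borrowed from EOS and left padding is used; the checkpoint and generation settings are illustrative.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Left padding keeps the last prompt token adjacent to the generated tokens.
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token (assumption for this sketch)

model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

prompts = [
    "What factors contributed to the fall of the Roman Empire?",
    "Summarize the idea behind state-space models in one sentence.",
]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")
outputs = model.generate(**batch, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

As the test suite in this PR notes, left padding can introduce small numerical differences in the first generated tokens, which the Mamba layers accentuate; this is expected.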
## Model card

The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)


## Issues

For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions)


## License

The model weights are open-sourced via an Apache 2.0 license.


## Zamba2Config

[[autodoc]] Zamba2Config


## Zamba2Model

[[autodoc]] Zamba2Model
    - forward


## Zamba2ForCausalLM

[[autodoc]] Zamba2ForCausalLM
    - forward


## Zamba2ForSequenceClassification

[[autodoc]] transformers.Zamba2ForSequenceClassification
    - forward
@@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
 * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
 * [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
+* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
 
 You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
 
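With Zamba2 on this list, the checkpoint can be loaded through the standard `attn_implementation` argument; a minimal sketch, assuming `flash-attn` is installed and a CUDA device with bf16 support is available:

```python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B",
    torch_dtype=torch.bfloat16,               # FlashAttention-2 requires fp16 or bf16 weights
    attn_implementation="flash_attention_2",  # only the shared attention layers use this path
    device_map="cuda",
)
```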
@@ -328,6 +329,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
 * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
 * [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
+* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
 
 <Tip>
 
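The SDPA path needs no extra dependency; the only change from the FlashAttention-2 sketch above is the `attn_implementation` value:

```python
from transformers import AutoModelForCausalLM
import torch

# Same loading call as above, but selecting PyTorch's scaled_dot_product_attention backend.
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B", torch_dtype=torch.bfloat16, attn_implementation="sdpa", device_map="cuda"
)
```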
@@ -889,6 +889,7 @@ _import_structure = {
     "models.yolos": ["YolosConfig"],
     "models.yoso": ["YosoConfig"],
     "models.zamba": ["ZambaConfig"],
+    "models.zamba2": ["Zamba2Config"],
     "models.zoedepth": ["ZoeDepthConfig"],
     "onnx": [],
     "pipelines": [
@@ -3989,6 +3990,14 @@ else:
             "ZambaPreTrainedModel",
         ]
     )
+    _import_structure["models.zamba2"].extend(
+        [
+            "Zamba2ForCausalLM",
+            "Zamba2ForSequenceClassification",
+            "Zamba2Model",
+            "Zamba2PreTrainedModel",
+        ]
+    )
     _import_structure["models.zoedepth"].extend(
         [
             "ZoeDepthForDepthEstimation",
@@ -6004,6 +6013,7 @@ if TYPE_CHECKING:
     from .models.yolos import YolosConfig
     from .models.yoso import YosoConfig
     from .models.zamba import ZambaConfig
+    from .models.zamba2 import Zamba2Config
     from .models.zoedepth import ZoeDepthConfig
 
     # Pipelines
@@ -8542,6 +8552,12 @@ if TYPE_CHECKING:
             ZambaModel,
             ZambaPreTrainedModel,
         )
+        from .models.zamba2 import (
+            Zamba2ForCausalLM,
+            Zamba2ForSequenceClassification,
+            Zamba2Model,
+            Zamba2PreTrainedModel,
+        )
         from .models.zoedepth import (
             ZoeDepthForDepthEstimation,
             ZoeDepthPreTrainedModel,
@@ -303,5 +303,6 @@ from . import (
     yolos,
     yoso,
     zamba,
+    zamba2,
     zoedepth,
 )
@@ -335,6 +335,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("yolos", "YolosConfig"),
         ("yoso", "YosoConfig"),
         ("zamba", "ZambaConfig"),
+        ("zamba2", "Zamba2Config"),
         ("zoedepth", "ZoeDepthConfig"),
     ]
 )
@@ -680,6 +681,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("yolos", "YOLOS"),
         ("yoso", "YOSO"),
         ("zamba", "Zamba"),
+        ("zamba2", "Zamba2"),
         ("zoedepth", "ZoeDepth"),
     ]
 )
@@ -303,6 +303,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("yolos", "YolosModel"),
         ("yoso", "YosoModel"),
         ("zamba", "ZambaModel"),
+        ("zamba2", "Zamba2Model"),
     ]
 )
 
@@ -577,6 +578,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("xlnet", "XLNetLMHeadModel"),
         ("xmod", "XmodForCausalLM"),
         ("zamba", "ZambaForCausalLM"),
+        ("zamba2", "Zamba2ForCausalLM"),
     ]
 )
 
@@ -1055,6 +1057,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("xmod", "XmodForSequenceClassification"),
         ("yoso", "YosoForSequenceClassification"),
         ("zamba", "ZambaForSequenceClassification"),
+        ("zamba2", "Zamba2ForSequenceClassification"),
     ]
 )
 
@@ -583,6 +583,13 @@ else:
                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        (
+            "zamba2",
+            (
+                "LlamaTokenizer" if is_sentencepiece_available() else None,
+                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
     ]
 )
 
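Given the mapping just added, `AutoTokenizer` resolves Zamba2 checkpoints to the Llama tokenizer classes; a quick sanity check (requires hub access, checkpoint name illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B")
print(type(tokenizer).__name__)  # expected: LlamaTokenizerFast when the tokenizers backend is installed
```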
@@ -272,7 +272,6 @@ class ZambaAttention(nn.Module):
         layer_idx: int,
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[ZambaHybridDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -621,11 +620,9 @@ class ZambaAttentionDecoderLayer(nn.Module):
         original_hidden_states: torch.Tensor,
         layer_idx: int,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -638,7 +635,6 @@ class ZambaAttentionDecoderLayer(nn.Module):
             layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                 `(batch, sequence_length)` where padding elements are indicated by 0.
-            position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings.
             past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -655,11 +651,9 @@ class ZambaAttentionDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             layer_idx=layer_idx,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
-            cache_position=cache_position,
             **kwargs,
         )
         # feed-forward (MLP)
@@ -688,12 +682,12 @@ class ZambaMambaDecoderLayer(nn.Module):
         layer_idx: int = None,
         attention_mask: Optional[torch.Tensor] = None,
         causal_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         transformer_hidden_states: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -756,7 +750,6 @@ class ZambaHybridLayer(nn.Module):
         layer_idx: int = None,
         attention_mask: Optional[torch.Tensor] = None,
         causal_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[ZambaHybridDynamicCache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -786,7 +779,6 @@ class ZambaHybridLayer(nn.Module):
             original_hidden_states=original_hidden_states,
             layer_idx=layer_idx,
             attention_mask=causal_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -804,7 +796,6 @@ class ZambaHybridLayer(nn.Module):
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -1108,7 +1099,6 @@ class ZambaModel(ZambaPreTrainedModel):
                     layer_idx,
                     attention_mask,
                     causal_mask,
-                    position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
@@ -1121,7 +1111,6 @@ class ZambaModel(ZambaPreTrainedModel):
                     layer_idx=layer_idx,
                     attention_mask=attention_mask,
                     causal_mask=causal_mask,
-                    position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
src/transformers/models/zamba2/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_zamba2 import *
    from .modeling_zamba2 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/zamba2/configuration_zamba2.py (new file, 236 lines)
@@ -0,0 +1,236 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_zamba2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig


class Zamba2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
    Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Zamba2 model.

    [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Zamba2Model`]
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 54):
            Number of hidden layers in the model.
        layers_block_type (`list`, *optional*):
            List of layer types, which can be either "mamba" or "hybrid".
        mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
        mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        mamba_ngroups (`int`, *optional*, defaults to 1):
            Number of groups for the evolution matrices of mamba 2.
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        time_step_limit (`tuple`, *optional*):
            Accepted range of time step values.
        n_mamba_heads (`int`, *optional*, defaults to 8):
            Number of heads for the evolution matrices of mamba 2.
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        chunk_size (`int`, *optional*, defaults to 256):
            Size of the chunks that will comprise the sequence.
        add_bias_linear (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in various layers
        intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
            Dimension of the MLP representations.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the MLP.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf).
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_mem_blocks (`int`, *optional*, defaults to 1):
            Number of unshared transformer blocks.
        use_shared_attention_adapter (`bool`, *optional*, defaults to `False`):
            If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers.
        adapter_rank (`int`, *optional*, defaults to 128):
            Rank of the adapter in the shared MLP and shared attention layers.
        use_mem_rope (`bool`, *optional*, defaults to `False`):
            If True, includes RoPE in the shared attention layers.
        rope_theta (`float`, *optional*, defaults to `10000.0`):
            The base period of the RoPE embeddings.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
            integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
            sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
            significantly.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        use_long_context (`bool`, *optional*, defaults to `False`):
            Activates the context-extended version of Zamba by modifying RoPE.
    ```python
    >>> from transformers import Zamba2Model, Zamba2Config
    >>> # Initializing a Zamba2-2.7B style configuration
    >>> configuration = Zamba2Config()
    >>> # Initializing a model from the Zamba2-2.7B style configuration
    >>> model = Zamba2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "zamba2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        max_position_embeddings=4096,
        hidden_size=2560,
        num_hidden_layers=54,
        layers_block_type=None,
        mamba_d_state=64,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_ngroups=1,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_floor=1e-4,
        time_step_limit=None,
        n_mamba_heads=8,
        use_conv_bias=True,
        chunk_size=256,
        add_bias_linear=False,
        intermediate_size=None,
        hidden_act="gelu",
        num_attention_heads=32,
        num_key_value_heads=None,
        attention_dropout=0.0,
        num_mem_blocks=1,
        use_shared_attention_adapter=False,
        adapter_rank=128,
        use_mem_rope=False,
        rope_theta=10000,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        num_logits_to_keep=1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        use_long_context=False,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        if intermediate_size is None:
            self.intermediate_size = 4 * hidden_size
        else:
            self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_mem_blocks = num_mem_blocks
        self.attention_hidden_size = 2 * hidden_size
        self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
        self.attention_dropout = attention_dropout
        self.use_mem_rope = use_mem_rope
        self.use_long_context = use_long_context
        if use_mem_rope and use_long_context:
            a = 8
            rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2))
        self.rope_theta = rope_theta
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.add_bias_linear = add_bias_linear
        self.mamba_ngroups = mamba_ngroups
        self.n_mamba_heads = n_mamba_heads
        self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
        self.use_conv_bias = use_conv_bias
        self.chunk_size = chunk_size
        self.time_step_limit = time_step_limit
        self.use_shared_attention_adapter = use_shared_attention_adapter
        self.adapter_rank = adapter_rank
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_floor = time_step_floor
        if use_long_context:
            self.max_position_embeddings = 16384
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_attention_heads = num_attention_heads
        self.kv_channels = self.hidden_size // self.num_attention_heads
        self.num_query_groups = self.num_attention_heads
        # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
        if layers_block_type is None:
            self.layers_block_type = (
                ["mamba"]
                + (["mamba"] * 5 + ["hybrid"]) * 7
                + ["mamba"] * 4
                + ["hybrid"]
                + ["mamba"] * 3
                + ["hybrid"]
                + ["mamba"] * 2
            )
        else:
            self.layers_block_type = layers_block_type
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep
        self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"]


__all__ = ["Zamba2Config"]
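The default `layers_block_type` pattern and the `use_long_context` RoPE rescaling above are easiest to see by instantiating the config; a small sketch (printed values follow the defaults in the generated file, not any pretrained checkpoint):

```python
from transformers import Zamba2Config

config = Zamba2Config()  # defaults correspond to a Zamba2-2.7B-style layout
print(len(config.layers_block_type))  # 54, one entry per hidden layer
print(config.layers_block_type[:8])   # six "mamba" blocks, the first shared "hybrid" layer, then "mamba" again
print(config.hybrid_layer_ids)        # indices of the shared-attention ("hybrid") layers

# use_long_context (together with use_mem_rope) rescales rope_theta and bumps max_position_embeddings to 16384
long_config = Zamba2Config(use_mem_rope=True, use_long_context=True)
print(long_config.rope_theta, long_config.max_position_embeddings)
```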
src/transformers/models/zamba2/modeling_zamba2.py (new file, 1909 lines; diff suppressed because it is too large)
src/transformers/models/zamba2/modular_zamba2.py (new file, 1156 lines; diff suppressed because it is too large)
@@ -1435,6 +1435,7 @@ def set_model_tester_for_less_flaky_test(test_case):
     # TODO (if possible): Avoid exceptional cases
     exceptional_classes = [
         "ZambaModelTester",
+        "Zamba2ModelTester",
         "RwkvModelTester",
         "AriaVisionText2TextModelTester",
         "GPTNeoModelTester",
@@ -10576,6 +10576,34 @@ class ZambaPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class Zamba2ForCausalLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class Zamba2ForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class Zamba2Model(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class Zamba2PreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ZoeDepthForDepthEstimation(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -2279,6 +2279,7 @@ class GenerationTesterMixin:
             "mamba",
             "xlnet",
             "zamba",
+            "zamba2",
         )
         has_standard_cache = not any(
             model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
tests/models/zamba2/__init__.py (new file, 0 lines)

tests/models/zamba2/test_modeling_zamba2.py (new file, 666 lines)
@@ -0,0 +1,666 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch Zamba model."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, Zamba2Config, is_torch_available
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_flash_attn,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
Zamba2ForCausalLM,
|
||||||
|
Zamba2ForSequenceClassification,
|
||||||
|
Zamba2Model,
|
||||||
|
)
|
||||||
|
from transformers.models.zamba2.modeling_zamba2 import (
|
||||||
|
Zamba2HybridDynamicCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Zamba2ModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=14,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=16,
|
||||||
|
mamba_d_state=2,
|
||||||
|
chunk_size=8,
|
||||||
|
mamba_dt_rank="auto",
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
n_mamba_heads=8,
|
||||||
|
mamba_ngroups=8,
|
||||||
|
intermediate_size=4,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_mamba_act="silu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
scope=None,
|
||||||
|
layers_block_type=["mamba", "hybrid"],
|
||||||
|
num_mem_blocks=1,
|
||||||
|
use_mem_rope=True,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.mamba_dt_rank = mamba_dt_rank
|
||||||
|
self.mamba_d_state = mamba_d_state
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.n_mamba_heads = n_mamba_heads
|
||||||
|
self.mamba_ngroups = mamba_ngroups
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_mamba_act = hidden_mamba_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.scope = scope
|
||||||
|
self.layers_block_type = layers_block_type
|
||||||
|
self.num_mem_blocks = num_mem_blocks
|
||||||
|
self.use_mem_rope = use_mem_rope
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return Zamba2Config(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
mamba_dt_rank=self.mamba_dt_rank,
|
||||||
|
mamba_d_state=self.mamba_d_state,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
n_mamba_heads=self.n_mamba_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
mamba_ngroups=self.mamba_ngroups,
|
||||||
|
hidden_mamba_act=self.hidden_mamba_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
is_decoder=True,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
use_mamba_kernels=False,
|
||||||
|
layers_block_type=self.layers_block_type,
|
||||||
|
num_mem_blocks=self.num_mem_blocks,
|
||||||
|
use_mem_rope=self.use_mem_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
|
model = Zamba2Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_causal_lm(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
):
|
||||||
|
model = Zamba2ForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
result = model(input_ids, labels=token_labels)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = False
|
||||||
|
model = Zamba2ForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
# Attention: Zamba2 needs the cache to be initialized to return a cache!
|
||||||
|
past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 1), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
cache_position=torch.arange(
|
||||||
|
input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
|
||||||
|
),
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -1:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def create_and_check_for_sequence_classification(
|
||||||
|
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = Zamba2ForSequenceClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
test_torchscript = False
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
Zamba2Model,
|
||||||
|
Zamba2ForCausalLM,
|
||||||
|
Zamba2ForSequenceClassification,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (Zamba2ForCausalLM,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": Zamba2Model,
|
||||||
|
"text-classification": Zamba2ForSequenceClassification,
|
||||||
|
"text-generation": Zamba2ForCausalLM,
|
||||||
|
"zero-shot": Zamba2ForSequenceClassification,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Zamba2ModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37)
|
||||||
|
|
||||||
|
@unittest.skip("position_ids cannot be used to pad due to Mamba2 layers")
|
||||||
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Zamba2 has a hybrid cache")
|
||||||
|
def test_past_key_values_format(self):
|
||||||
|
r"""
|
||||||
|
Zamba2's cache shape depends on whether a given layer is mamba or attention.
|
||||||
|
For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0].
|
||||||
|
The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_causal_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_sequence_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
r"""
|
||||||
|
Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently
|
||||||
|
"""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if "A_log" in name:
|
||||||
|
A = torch.arange(1, config.n_mamba_heads + 1, dtype=torch.float32)[None, :]
|
||||||
|
self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
|
||||||
|
elif "D" in name:
|
||||||
|
# check if it's a ones like
|
||||||
|
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||||
|
elif "dt_bias" in name:
|
||||||
|
dt = torch.exp(
|
||||||
|
torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
|
||||||
|
+ math.log(config.time_step_min)
|
||||||
|
).clamp(min=config.time_step_floor)
|
||||||
|
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||||
|
if param.requires_grad:
|
||||||
|
self.assertTrue(param.data.max().item() <= inv_dt[1])
|
||||||
|
self.assertTrue(param.data.min().item() >= inv_dt[0])
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Cumbersome and redundant for Zamba2")
|
||||||
|
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||||
|
r"""
|
||||||
|
Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
|
||||||
|
Mamba block are initialized differently and we tested that in test_initialization
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
    def test_attention_outputs(self):
        r"""
        Overriding the test_attention_outputs test as the Zamba2 model outputs attentions only for its attention layers
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions

            # check that output_attentions also works when set via the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )

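    # Background for the override above: Zamba2 is a hybrid architecture that mixes Mamba2 layers with (shared)
    # attention layers, so outputs.attentions contains one entry per attention layer rather than one per decoder
    # layer; the checks therefore only inspect the shape of attentions[0] instead of asserting a fixed count.
    # added_hidden_states = 1 because enabling output_hidden_states adds exactly one extra field (hidden_states)
    # to the model output.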
    def _get_input_ids_and_config(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        return config, input_ids, input_mask

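    # The helper above drops the label tensors returned by prepare_config_and_inputs: the padding and generation
    # tests below only need the config, the input ids, and the attention mask.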
    def test_left_padding_compatibility(self):
        r"""
        Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
        caused by left padding (see the issue referenced in the note below), so a more permissive tolerance is used.
        """
        import inspect
        # NOTE: left-padding results in small numerical differences. This is expected.
        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

        # First, filter out models that don't support left padding - generative and decoder-only.
        # Zamba2 is a decoder-only architecture
        decoder_only_classes = self.all_generative_model_classes

        # Then, test left-padding
        def _prepare_model_kwargs(input_ids, attention_mask, signature):
            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
            if "position_ids" in signature:
                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                model_kwargs["position_ids"] = position_ids
            if "cache_position" in signature:
                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
                model_kwargs["cache_position"] = cache_position
            return model_kwargs

        for model_class in decoder_only_classes:
            config, input_ids, attention_mask = self._get_input_ids_and_config()
            model = model_class(config).to(torch_device).eval()
            signature = inspect.signature(model.forward).parameters.keys()

            # Without padding
            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
            next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

            # With left-padding (length 32)
            pad_size = (input_ids.shape[0], 32)
            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
            padded_input_ids = torch.cat((padding, input_ids), dim=1)
            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
            model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
            next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

            # They should result in very similar logits
            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

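    # Worked example (illustrative values only) of what _prepare_model_kwargs builds for a left-padded row:
    #     attention_mask = [0, 0, 1, 1, 1]
    #     torch.cumsum(attention_mask, dim=-1) - 1  ->  [-1, -1, 0, 1, 2]
    #     .masked_fill_(attention_mask == 0, 1)     ->  [ 1,  1, 0, 1, 2]
    # The real tokens keep the same positions they would have without padding, which is why the last-token logits
    # of the padded and unpadded calls are expected to match up to the relaxed atol=3e-3.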
    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_fp32_ln(self):
        r"""
        Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support
        right padding + use cache with FA2
        """
        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_input = inputs_dict[model.main_input_name]
                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # NOTE: Zamba2 does not support right padding + use_cache with FA2.
                dummy_attention_mask[:, -1] = 1

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                    load_in_4bit=True,
                )

                for _, param in model.named_parameters():
                    # upcast only layer norms
                    if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                        param.data = param.data.to(torch.float32)

                _ = model(dummy_input)
                # with attention mask
                _ = model(dummy_input, attention_mask=dummy_attention_mask)

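    # Note on the "upcast only layer norms" loop above: load_in_4bit quantizes the nn.Linear weights, so the
    # parameters still stored in fp16/bf16 after loading are, roughly speaking, the non-quantized ones (layer
    # norms among them). Upcasting everything still in half precision is therefore a simple way to run the norms
    # in fp32 while flash attention keeps operating on fp16 activations.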
    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        r"""
        Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't
        support right padding + use cache with FA2
        """
        self.skipTest(reason="Zamba2 flash attention does not support right padding")

@unittest.skip(reason="Zamba2 has its own special cache type")
|
||||||
|
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||||
|
def test_new_cache_format(self, num_beams, do_sample):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class Zamba2ModelIntegrationTest(unittest.TestCase):
    model = None
    tokenizer = None

    @classmethod
    @slow
    def setUpClass(cls):
        model_id = "Zyphra/Zamba2-1.2B"
        cls.model = Zamba2ForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True, revision="PR"
        )
        cls.tokenizer = AutoTokenizer.from_pretrained(model_id, revision="PR")

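    # revision="PR" pins the download to a specific revision (branch) of Zyphra/Zamba2-1.2B on the Hub, so the
    # hard-coded generations and logit slices below stay tied to that exact set of weights.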
    @parameterized.expand([(torch_device,), ("cpu",)])
    @slow
    def test_simple_generate(self, torch_device):
        self.model.to(torch_device)

        input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
            "input_ids"
        ].to(torch_device)
        out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(out[0, :])
        self.assertEqual(
            output_sentence,
            "<s> Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
        )

        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits.to(dtype=torch.float32)

        EXPECTED_LOGITS_NO_GRAD = torch.tensor(
            [
                -5.9587, 10.5152, 7.0382, -2.8728, -4.8143, -4.8142, -4.8142, -4.8144,
                -4.8143, -4.8143, -4.8142, -4.8142, 6.0185, 18.0037, -4.8142, -4.8144,
                -4.8143, -4.8142, -4.8143, -4.8143, -4.8143, -4.8143, -4.8142, -4.8143,
                -4.8144, -4.8143, -4.8143, -4.8141, -4.8142, -4.8142, -4.8142, -4.8144,
                -4.8143, -4.8143, -4.8143, -4.8142, -4.8144, -4.8144, -4.8142, -4.8142
            ],
            dtype=torch.float32,
        )  # fmt: skip

        torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)

    @parameterized.expand([(torch_device,), ("cpu",)])
    @slow
    def test_simple_batched_generate_with_padding(self, torch_device):
        self.model.to(torch_device)

        inputs = self.tokenizer(
            ["Hey how are you doing on this lovely evening?", "When did the Roman empire "],
            padding=True,
            return_tensors="pt",
        ).to(torch_device)
        out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
        output_sentences = self.tokenizer.batch_decode(out)
        self.assertEqual(
            output_sentences[0],
            "<s> Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
        )

        self.assertEqual(
            output_sentences[1],
            "[PAD][PAD][PAD][PAD]<s> When did the Roman empire 1st fall?\nThe Roman Empire fell in",
        )

        with torch.no_grad():
            logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(
                dtype=torch.float32
            )

        EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
            [
                -5.9611, 10.5208, 7.0411, -2.8743, -4.8167, -4.8167, -4.8167, -4.8168,
                -4.8167, -4.8167, -4.8167, -4.8166, 6.0218, 18.0062, -4.8167, -4.8168,
                -4.8167, -4.8167, -4.8167, -4.8168, -4.8168, -4.8168, -4.8167, -4.8167,
                -4.8168, -4.8167, -4.8167, -4.8165, -4.8167, -4.8167, -4.8167, -4.8169,
                -4.8168, -4.8168, -4.8168, -4.8166, -4.8169, -4.8168, -4.8167, -4.8167
            ],
            dtype=torch.float32,
        )  # fmt: skip

        EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
            [
                0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
                -6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
                -6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
                -6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
                -6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105
            ],
            dtype=torch.float32,
        )  # fmt: skip

        torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(
            logits[1, -1, :40].cpu(),
            EXPECTED_LOGITS_NO_GRAD_1,
            rtol=1e-3,
            atol=6e-3 if torch_device == "cpu" else 1e-3,
        )