mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[MODEL] Add Falcon H1 (#38249)
* Create push-important-models.yml * feat: add falcon-h1 * fixup * address comment * fix * fix copies * fix copies * fix * fix * fix * fix * fix copies * fix * fix copies * fix test import to at least trigget the cis * yups * update * fix make fix copies * fix inits? * fix style * skip annoying test * add integration test for Falcon H1 * fix copies * fix --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
This commit is contained in:
parent
e288ee00d8
commit
6829936ee0
@ -455,6 +455,8 @@
|
||||
title: Falcon
|
||||
- local: model_doc/falcon3
|
||||
title: Falcon3
|
||||
- local: model_doc/falcon_h1
|
||||
title: FalconH1
|
||||
- local: model_doc/falcon_mamba
|
||||
title: FalconMamba
|
||||
- local: model_doc/flan-t5
|
||||
|
65
docs/source/en/model_doc/falcon_h1.md
Normal file
65
docs/source/en/model_doc/falcon_h1.md
Normal file
@ -0,0 +1,65 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# FalconH1
|
||||
|
||||
## Overview
|
||||
|
||||
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series in [this website](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
## Contributors
|
||||
|
||||
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
|
||||
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
|
||||
## FalconH1Config
|
||||
|
||||
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|
||||
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
|
||||
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
|
||||
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
|
||||
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
|
||||
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
|
||||
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
|
||||
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
|
||||
|
||||
|
||||
|
||||
[[autodoc]] FalconH1Config
|
||||
|
||||
<!---
|
||||
## Usage Tips
|
||||
Tips:
|
||||
- The architecture is based on Mamba-2 models.
|
||||
## FalconH1Model
|
||||
[[autodoc]] FalconH1Model
|
||||
- forward
|
||||
-->
|
||||
|
||||
## FalconH1ForCausalLM
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
|
||||
message = ["Mamba is a snake with following properties "]
|
||||
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
|
||||
response = model.generate(**inputs, max_new_tokens=64)
|
||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||
```
|
||||
|
||||
[[autodoc]] FalconH1ForCausalLM
|
||||
- forward
|
||||
|
||||
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).
|
@ -1985,7 +1985,9 @@ class GenerationMixin:
|
||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||
"""
|
||||
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
|
||||
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
|
||||
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
|
||||
|
@ -103,6 +103,7 @@ if TYPE_CHECKING:
|
||||
from .ernie import *
|
||||
from .esm import *
|
||||
from .falcon import *
|
||||
from .falcon_h1 import *
|
||||
from .falcon_mamba import *
|
||||
from .fastspeech2_conformer import *
|
||||
from .flaubert import *
|
||||
|
@ -118,6 +118,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("ernie_m", "ErnieMConfig"),
|
||||
("esm", "EsmConfig"),
|
||||
("falcon", "FalconConfig"),
|
||||
("falcon_h1", "FalconH1Config"),
|
||||
("falcon_mamba", "FalconMambaConfig"),
|
||||
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
|
||||
("flaubert", "FlaubertConfig"),
|
||||
@ -481,6 +482,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("esm", "ESM"),
|
||||
("falcon", "Falcon"),
|
||||
("falcon3", "Falcon3"),
|
||||
("falcon_h1", "FalconH1"),
|
||||
("falcon_mamba", "FalconMamba"),
|
||||
("fastspeech2_conformer", "FastSpeech2Conformer"),
|
||||
("flan-t5", "FLAN-T5"),
|
||||
|
@ -115,6 +115,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("ernie_m", "ErnieMModel"),
|
||||
("esm", "EsmModel"),
|
||||
("falcon", "FalconModel"),
|
||||
("falcon_h1", "FalconH1Model"),
|
||||
("falcon_mamba", "FalconMambaModel"),
|
||||
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
@ -558,6 +559,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
("falcon_h1", "FalconH1ForCausalLM"),
|
||||
("falcon_mamba", "FalconMambaForCausalLM"),
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma", "GemmaForCausalLM"),
|
||||
|
27
src/transformers/models/falcon_h1/__init__.py
Normal file
27
src/transformers/models/falcon_h1/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_falcon_h1 import *
|
||||
from .modeling_falcon_h1 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
283
src/transformers/models/falcon_h1/configuration_falcon_h1.py
Normal file
283
src/transformers/models/falcon_h1/configuration_falcon_h1.py
Normal file
@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""FalconH1 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FalconH1Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
|
||||
FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf).
|
||||
The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
|
||||
The checkpoints are jointly trained by IBM, Princeton, and UIUC.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128000):
|
||||
Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`FalconH1Model`]
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
||||
model has a output word embedding layer.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
||||
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
|
||||
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
|
||||
sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
|
||||
significantly.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the "end-of-sequence" token.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
Max cached sequence length for the model
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mamba_d_ssm (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the SSM state space latents.
|
||||
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||
The number of mamba heads used in the v2 implementation.
|
||||
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||
Head embeddding dimension size
|
||||
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||
The number of the mamba groups used in the v2 implementation.
|
||||
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||
The dimension the mamba state space latents
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
The chunks in which to break the sequence when doing prefill/training
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
|
||||
mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use RMSNorm before the gate in the Mamba block
|
||||
mamba_rms_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use RMSNorm instead of LayerNorm in the Mamba block
|
||||
projectors_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block
|
||||
rope_theta (`float`, *optional*, defaults to 100000.0):
|
||||
The theta value used for the RoPE embeddings.
|
||||
rope_scaling (`float`, *optional*):
|
||||
The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.
|
||||
lm_head_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the LM head. This is used to scale the output of the LM head.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
|
||||
mlp_multipliers (`List[float]`, *optional*):
|
||||
The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
|
||||
the multiplier of gate layer, the second value is the multiplier of the down_proj layer.
|
||||
key_multiplier (`float`, *optional*):
|
||||
The multiplier for the key layer. This is used to scale the output of the key layer.
|
||||
attention_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention output layer. This is used to scale the output of the attention output
|
||||
attention_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention input layer. This is used to scale the output of the attention input layer.
|
||||
ssm_multipliers (`List[float]`, *optional*):
|
||||
The multipliers for the SSM layers. This is used to scale the output of the SSM layers.
|
||||
ssm_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.
|
||||
ssm_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.
|
||||
"""
|
||||
|
||||
model_type = "falcon_h1"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128000,
|
||||
tie_word_embeddings=False,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
max_position_embeddings=8192,
|
||||
attention_dropout=0.0,
|
||||
mamba_d_ssm=1024,
|
||||
mamba_n_heads=128,
|
||||
mamba_d_head="auto",
|
||||
mamba_n_groups=1,
|
||||
mamba_d_state=256,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=256,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
mamba_norm_before_gate=True,
|
||||
mamba_rms_norm=False,
|
||||
projectors_bias=False,
|
||||
rope_theta=100000.0,
|
||||
rope_scaling=None,
|
||||
lm_head_multiplier=1.0,
|
||||
embedding_multiplier=1.0,
|
||||
mlp_multipliers=None,
|
||||
key_multiplier=None,
|
||||
attention_out_multiplier=None,
|
||||
attention_in_multiplier=None,
|
||||
ssm_multipliers=None,
|
||||
ssm_in_multiplier=None,
|
||||
ssm_out_multiplier=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = False
|
||||
self.mlp_bias = False
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.num_logits_to_keep = num_logits_to_keep
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = None
|
||||
self.rope_scaling = rope_scaling
|
||||
self.projectors_bias = projectors_bias
|
||||
mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
|
||||
|
||||
if mamba_intermediate % mamba_n_heads != 0:
|
||||
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||
|
||||
# for the mamba_v2, must satisfy the following
|
||||
if mamba_d_head == "auto":
|
||||
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||
|
||||
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
|
||||
|
||||
self.mamba_d_ssm = mamba_d_ssm
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_head = mamba_d_head
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.mamba_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
|
||||
self.mamba_norm_before_gate = mamba_norm_before_gate
|
||||
self.mamba_rms_norm = mamba_rms_norm
|
||||
|
||||
self.lm_head_multiplier = lm_head_multiplier
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
|
||||
if mlp_multipliers is not None:
|
||||
self.mlp_multipliers = mlp_multipliers
|
||||
else:
|
||||
self.mlp_multipliers = [1.0, 1.0]
|
||||
|
||||
if attention_out_multiplier is not None:
|
||||
self.attention_out_multiplier = attention_out_multiplier
|
||||
else:
|
||||
self.attention_out_multiplier = 1.0
|
||||
|
||||
if attention_in_multiplier is not None:
|
||||
self.attention_in_multiplier = attention_in_multiplier
|
||||
else:
|
||||
self.attention_in_multiplier = 1.0
|
||||
|
||||
if key_multiplier is not None:
|
||||
self.key_multiplier = key_multiplier
|
||||
else:
|
||||
self.key_multiplier = 1.0
|
||||
|
||||
if ssm_multipliers is not None:
|
||||
self.ssm_multipliers = ssm_multipliers
|
||||
else:
|
||||
#
|
||||
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
if ssm_in_multiplier is not None:
|
||||
self.ssm_in_multiplier = ssm_in_multiplier
|
||||
else:
|
||||
self.ssm_in_multiplier = 1.0
|
||||
|
||||
if ssm_out_multiplier is not None:
|
||||
self.ssm_out_multiplier = ssm_out_multiplier
|
||||
else:
|
||||
self.ssm_out_multiplier = 1.0
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return ["attention" for i in range(self.num_hidden_layers)]
|
||||
|
||||
|
||||
__all__ = ["FalconH1Config"]
|
@ -0,0 +1,151 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, FalconH1Config, FalconH1ForCausalLM
|
||||
|
||||
|
||||
CONVERSION_MAPPING = {
|
||||
"backbone": "model",
|
||||
"embeddings": "embed_tokens",
|
||||
"mixer.": "",
|
||||
"mixer_ssm": "mamba",
|
||||
"mixer_attn": "self_attn",
|
||||
"mlp.": "feed_forward.",
|
||||
"mlp_norm": "pre_ff_layernorm",
|
||||
"ssm_proj": "mamba.in_proj",
|
||||
"attn_out_proj": "o_proj",
|
||||
".norm.": ".input_layernorm.",
|
||||
".mamba.input_layernorm.": ".mamba.norm.",
|
||||
".ssm_out_proj.": ".mamba.out_proj.",
|
||||
"norm_f": "final_layernorm",
|
||||
}
|
||||
|
||||
|
||||
def convert_falcon_h1_to_hf(input_model_path, output_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(input_model_path)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
input_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
intermediate_size = int(model.config.expansion_factor * model.config.hidden_size)
|
||||
|
||||
if intermediate_size % 2 != 0:
|
||||
intermediate_size = intermediate_size + (intermediate_size % 2)
|
||||
|
||||
new_config = FalconH1Config(
|
||||
vocab_size=model.config.vocab_size,
|
||||
tie_word_embeddings=model.config.tie_word_embeddings,
|
||||
hidden_size=model.config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
mamba_d_state=model.config.state_size,
|
||||
num_hidden_layers=model.config.num_hidden_layers,
|
||||
mamba_use_mlp=model.config.use_mlp,
|
||||
rms_norm_eps=model.config.layer_norm_epsilon,
|
||||
pad_token_id=model.config.pad_token_id,
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
mamba_expand=model.config.expand,
|
||||
mamba_d_conv=model.config.conv_kernel,
|
||||
mamba_n_groups=model.config.n_groups,
|
||||
mamba_n_heads=model.config.num_heads,
|
||||
mamba_norm_before_gate=model.config.norm_before_gate,
|
||||
mamba_rms_norm=model.config.rms_norm,
|
||||
mamba_d_ssm=model.config.d_ssm,
|
||||
attention_bias=model.config.use_bias,
|
||||
projectors_bias=model.config.use_bias,
|
||||
mamba_conv_bias=model.config.use_conv_bias,
|
||||
hidden_act=model.config.hidden_act,
|
||||
use_cache=model.config.use_cache,
|
||||
mamba_chunk_size=model.config.chunk_size,
|
||||
num_attention_heads=model.config.num_heads_mha,
|
||||
num_key_value_heads=model.config.num_key_value_heads,
|
||||
head_dim=model.config.head_dim_mha,
|
||||
lm_head_multiplier=model.config.lm_head_multiplier,
|
||||
embedding_multiplier=model.config.embedding_multiplier,
|
||||
mlp_multipliers=model.config.mlp_multipliers,
|
||||
key_multiplier=model.config.key_multiplier,
|
||||
attention_out_multiplier=model.config.attention_out_multiplier,
|
||||
attention_in_multiplier=model.config.attention_in_multiplier,
|
||||
ssm_multipliers=model.config.ssm_multipliers,
|
||||
ssm_in_multiplier=model.config.ssm_in_multiplier,
|
||||
ssm_out_multiplier=model.config.ssm_out_multiplier,
|
||||
rope_theta=model.config.rope_theta,
|
||||
)
|
||||
|
||||
old_state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
|
||||
for old_key, old_value in old_state_dict.items():
|
||||
new_key = old_key
|
||||
for conversion_key, conversion_value in CONVERSION_MAPPING.items():
|
||||
if conversion_key in old_key:
|
||||
new_key = new_key.replace(conversion_key, conversion_value)
|
||||
|
||||
if "mamba.input_layernorm" in new_key:
|
||||
new_key = new_key.replace("mamba.input_layernorm", "mamba.norm")
|
||||
|
||||
# Special processing for attention layers
|
||||
if "self_attn.attn_proj" in new_key:
|
||||
num_heads = new_config.num_attention_heads
|
||||
num_kv_heads = new_config.num_key_value_heads
|
||||
head_dim = new_config.head_dim
|
||||
q_proj, k_proj, v_proj = old_value.split(
|
||||
[
|
||||
num_heads * head_dim,
|
||||
num_kv_heads * head_dim,
|
||||
num_kv_heads * head_dim,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
new_state_dict[new_key.replace("attn_proj", "q_proj")] = q_proj
|
||||
new_state_dict[new_key.replace("attn_proj", "k_proj")] = k_proj
|
||||
new_state_dict[new_key.replace("attn_proj", "v_proj")] = v_proj
|
||||
else:
|
||||
new_state_dict[new_key] = old_value
|
||||
|
||||
with torch.device("meta"):
|
||||
new_model = FalconH1ForCausalLM(new_config)
|
||||
|
||||
del model
|
||||
|
||||
new_model.load_state_dict(new_state_dict, strict=True, assign=True)
|
||||
|
||||
new_model.save_pretrained(output_path)
|
||||
tokenizer.save_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_falcon_h1_to_hf(
|
||||
args.mamba_ssm_checkpoint_directory,
|
||||
args.output_dir,
|
||||
)
|
1692
src/transformers/models/falcon_h1/modeling_falcon_h1.py
Normal file
1692
src/transformers/models/falcon_h1/modeling_falcon_h1.py
Normal file
File diff suppressed because it is too large
Load Diff
1442
src/transformers/models/falcon_h1/modular_falcon_h1.py
Normal file
1442
src/transformers/models/falcon_h1/modular_falcon_h1.py
Normal file
File diff suppressed because it is too large
Load Diff
0
tests/models/falcon_h1/__init__.py
Normal file
0
tests/models/falcon_h1/__init__.py
Normal file
518
tests/models/falcon_h1/test_modeling_falcon_h1.py
Normal file
518
tests/models/falcon_h1/test_modeling_falcon_h1.py
Normal file
@ -0,0 +1,518 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch FalconH1 model."""
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import FalconH1Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
|
||||
from transformers.models.falcon_h1.modeling_falcon_h1 import (
|
||||
FalconHybridMambaAttentionDynamicCache,
|
||||
)
|
||||
|
||||
|
||||
class FalconH1ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
intermediate_size=64,
|
||||
hidden_act="silu",
|
||||
attention_dropout=0.0,
|
||||
attn_layer_indices=None,
|
||||
attn_rotary_emb=8,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
pad_token_id=0,
|
||||
mamba_n_groups=1,
|
||||
mamba_n_heads=16,
|
||||
mamba_d_state=16,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=16,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_layer_indices = attn_layer_indices
|
||||
self.attn_rotary_emb = attn_rotary_emb
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask, token_labels
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
# Fix for SDPA tests, force at least 4 layers
|
||||
if self.num_hidden_layers < 4:
|
||||
self.num_hidden_layers = 4
|
||||
if self.attn_layer_indices is None:
|
||||
d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
|
||||
if len(d) == 0:
|
||||
raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
|
||||
d = d[-1] # get the largest divisor
|
||||
self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
|
||||
|
||||
return FalconH1Config(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
attention_dropout=self.attention_dropout,
|
||||
attn_layer_indices=self.attn_layer_indices,
|
||||
attn_rotary_emb=self.attn_rotary_emb,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
mamba_n_groups=self.mamba_n_groups,
|
||||
mamba_n_heads=self.mamba_n_heads,
|
||||
mamba_d_state=self.mamba_d_state,
|
||||
mamba_d_conv=self.mamba_d_conv,
|
||||
mamba_expand=self.mamba_expand,
|
||||
mamba_chunk_size=self.mamba_chunk_size,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = FalconH1Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = FalconH1ForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids, labels=token_labels)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
# config.is_decoder = True
|
||||
# config.add_cross_attention = True
|
||||
model = FalconH1ForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
# Attention: Jamba needs the cache to be initialized to return a cache!
|
||||
past_key_values = FalconHybridMambaAttentionDynamicCache(
|
||||
config,
|
||||
input_ids.shape[0],
|
||||
model.dtype,
|
||||
devices=[model.device for _ in range(model.config.num_hidden_layers)],
|
||||
)
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
cache_position=torch.arange(
|
||||
input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
|
||||
),
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
|
||||
@require_torch
|
||||
class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FalconH1Model, FalconH1ForCausalLM) if is_torch_available() else ()
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FalconH1ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_casual_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
# def test_initialization(self):
|
||||
# r"""
|
||||
# Overriding the test_initialization test as the A_log and D params of the FalconH1 mixer are initialized differently
|
||||
# """
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# configs_no_init = _config_zero_init(config)
|
||||
# for model_class in self.all_model_classes:
|
||||
# model = model_class(config=configs_no_init)
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.requires_grad:
|
||||
# if "A_log" in name:
|
||||
# A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)
|
||||
# torch.testing.assert_close(param.data, torch.log(A), rtol=1e-5, atol=1e-5)
|
||||
# elif "D" in name:
|
||||
# D = torch.ones(config.mamba_n_heads, dtype=torch.float32)
|
||||
# torch.testing.assert_close(param.data, D, rtol=1e-5, atol=1e-5)
|
||||
# else:
|
||||
# self.assertIn(
|
||||
# ((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
# [0.0, 1.0],
|
||||
# msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
# )
|
||||
|
||||
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||
r"""
|
||||
Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
|
||||
FalconH1 mixer are initialized differently and we tested that in test_initialization
|
||||
"""
|
||||
self.skipTest(reason="Cumbersome and redundant for FalconH1")
|
||||
|
||||
def test_attention_outputs(self):
|
||||
r"""
|
||||
Overriding the test_attention_outputs test as the FalconH1 model outputs attention only for its attention layers
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
expected_num_attentions = self.model_tester.num_hidden_layers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), expected_num_attentions)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_batching_equivalence(self):
|
||||
# need to disable the tril input mask
|
||||
orig = self.model_tester.use_input_mask
|
||||
self.model_tester.use_input_mask = False
|
||||
super().test_batching_equivalence()
|
||||
self.model_tester.use_input_mask = orig
|
||||
|
||||
# essentially the same test in test_utils, just adjustment for rtol for this model
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# NOTE: left-padding results in small numerical differences. This is expected.
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
|
||||
|
||||
# First, filter out models that don't support left padding
|
||||
# - The model must have generative capabilities
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest(reason="No generative architecture available for this model.")
|
||||
|
||||
# - The model must support padding
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="This model doesn't support padding.")
|
||||
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _ = self.prepare_config_and_inputs_for_generate()
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
else:
|
||||
decoder_only_classes.append(model_class)
|
||||
if len(decoder_only_classes) == 0:
|
||||
self.skipTest(reason="No decoder-only architecture available for this model.")
|
||||
|
||||
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
|
||||
# added support for it yet. We skip these models for now.
|
||||
has_encoder_attributes = any(
|
||||
attr_name
|
||||
for attr_name in config.to_dict().keys()
|
||||
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
|
||||
)
|
||||
if has_encoder_attributes:
|
||||
self.skipTest(
|
||||
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
|
||||
)
|
||||
|
||||
# Then, test left-padding
|
||||
def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
||||
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
if "position_ids" in signature:
|
||||
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
model_kwargs["position_ids"] = position_ids
|
||||
if "cache_position" in signature:
|
||||
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# - for left padding we absolutely need to use an all ones
|
||||
# attention mask, so we do not use the one in inputs_dict
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
# no cache as some models require special cache classes to be init outside forward
|
||||
model.generation_config.use_cache = False
|
||||
|
||||
# Without padding
|
||||
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
|
||||
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# With left-padding (length 32)
|
||||
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||
pad_token_id = (
|
||||
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||
)
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
class FalconH1ModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_llama_3_1_hard(self):
|
||||
"""
|
||||
An integration test for Falcon-H1.
|
||||
"""
|
||||
EXPECTED_TEXT = (
|
||||
"Tell me about the french revolution.\n"
|
||||
"The French Revolution (1789–1799) was a period of radical social and political upheaval in France that "
|
||||
"fundamentally transformed the nation and had profound effects on the rest of Europe and the world. Here are the key aspects of the revolution:\n\n"
|
||||
"### **Causes**\n"
|
||||
"1. **Economic Crisis**: France was in severe financial trouble due to costly wars (particularly the American Revolution), extravagant spending by the monarchy, and inefficient taxation.\n"
|
||||
"2. **Social Inequality**: The rigid class system (the Ancien Régime) divided society into the privileged nobility and clergy (First Estate) and the common people (Third Estate), who bore the brunt of taxation and had few rights.\n"
|
||||
"3. **Enlightenment Ideas**: Philosophers like Rousseau, Voltaire, and Montesquieu inspired ideas of liberty, equality, and popular sovereignty.\n"
|
||||
"4. **Settlement of 1789**: The Estates-General convened to address the financial crisis, leading to the Third Estate's assertion of its rights and the eventual formation of the National Assembly.\n\n"
|
||||
"### **Key Events**\n"
|
||||
"1. **Opening of the Revolution (1789)**:\n"
|
||||
"- **Storming of the Bastille**: Symbolic of the fall of royal tyranny.\n"
|
||||
"- **Declaration of the Rights of Man and of the Citizen**: Proclaimed universal rights to liberty, property, and security.\n"
|
||||
"- **Creation of the National Assembly**: The Third Estate declared itself the representative body of France.\n\n"
|
||||
"2. **Radical Phase (1792–1794)**:\n"
|
||||
"- **Reign of Terror**: Led by Maximilien Robespierre, the Committee of Public Safety enforced radical egalitarianism through the guillotine, executing thousands of perceived enemies of the revolution (monarchists, clergy, aristocrats, and counter-revolutionaries).\n"
|
||||
"- **Execution of Louis XVI**: The king was guillotined in June 1793, symbolizing the end of the monarchy.\n"
|
||||
)
|
||||
|
||||
model_id = "tiiuae/Falcon-H1-1.5B-Deep-Instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = FalconH1ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
device = "cuda"
|
||||
messages = [{"role": "user", "content": "Tell me about the french revolution."}]
|
||||
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False)
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(generated_text, EXPECTED_TEXT)
|
Loading…
Reference in New Issue
Block a user