From 6829936ee0a32928f248f2d98341952c80f7572a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 21 May 2025 12:43:11 +0400 Subject: [PATCH] [MODEL] Add Falcon H1 (#38249) * Create push-important-models.yml * feat: add falcon-h1 * fixup * address comment * fix * fix copies * fix copies * fix * fix * fix * fix * fix copies * fix * fix copies * fix test import to at least trigget the cis * yups * update * fix make fix copies * fix inits? * fix style * skip annoying test * add integration test for Falcon H1 * fix copies * fix --------- Co-authored-by: Arthur Zucker Co-authored-by: dhia.rhaiem --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/falcon_h1.md | 65 + src/transformers/generation/utils.py | 4 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/falcon_h1/__init__.py | 27 + .../falcon_h1/configuration_falcon_h1.py | 283 +++ .../falcon_h1/convert_mamba_ssm_checkpoint.py | 151 ++ .../models/falcon_h1/modeling_falcon_h1.py | 1692 +++++++++++++++++ .../models/falcon_h1/modular_falcon_h1.py | 1442 ++++++++++++++ tests/models/falcon_h1/__init__.py | 0 .../falcon_h1/test_modeling_falcon_h1.py | 518 +++++ 13 files changed, 4188 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/falcon_h1.md create mode 100644 src/transformers/models/falcon_h1/__init__.py create mode 100644 src/transformers/models/falcon_h1/configuration_falcon_h1.py create mode 100644 src/transformers/models/falcon_h1/convert_mamba_ssm_checkpoint.py create mode 100644 src/transformers/models/falcon_h1/modeling_falcon_h1.py create mode 100644 src/transformers/models/falcon_h1/modular_falcon_h1.py create mode 100644 tests/models/falcon_h1/__init__.py create mode 100644 tests/models/falcon_h1/test_modeling_falcon_h1.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 44c9a75aa79..873df4aa86a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -455,6 +455,8 @@ title: Falcon - local: model_doc/falcon3 title: Falcon3 + - local: model_doc/falcon_h1 + title: FalconH1 - local: model_doc/falcon_mamba title: FalconMamba - local: model_doc/flan-t5 diff --git a/docs/source/en/model_doc/falcon_h1.md b/docs/source/en/model_doc/falcon_h1.md new file mode 100644 index 00000000000..96d2ea8decb --- /dev/null +++ b/docs/source/en/model_doc/falcon_h1.md @@ -0,0 +1,65 @@ + + +# FalconH1 + +## Overview + +The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series in [this website](https://github.com/tiiuae/Falcon-H1). + +## Contributors + +This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm). +The original code can be found [here](https://github.com/tiiuae/Falcon-H1). + + +## FalconH1Config + +| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len | +|-----------|--------|------|------------|----|--------------|--------------|------|-----------------| +| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT | +| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K | +| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K | +| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K | +| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K | +| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K | + + + +[[autodoc]] FalconH1Config + + + +## FalconH1ForCausalLM + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct") + +message = ["Mamba is a snake with following properties "] +inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False) +response = model.generate(**inputs, max_new_tokens=64) +print(tokenizer.batch_decode(response, skip_special_tokens=True)[0]) +``` + +[[autodoc]] FalconH1ForCausalLM + - forward + +This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem). \ No newline at end of file diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 02513e0d848..510f55d824b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1985,7 +1985,9 @@ class GenerationMixin: instantiated, writes it to `model_kwargs`, under the name expected by the model. """ - cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"]) + cache_name = "past_key_values" if not is_hybrid_cache else "cache_params" + requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8d713b482ba..d2988a960b8 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -103,6 +103,7 @@ if TYPE_CHECKING: from .ernie import * from .esm import * from .falcon import * + from .falcon_h1 import * from .falcon_mamba import * from .fastspeech2_conformer import * from .flaubert import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 01c58a50629..bebf72e45f0 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -118,6 +118,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("ernie_m", "ErnieMConfig"), ("esm", "EsmConfig"), ("falcon", "FalconConfig"), + ("falcon_h1", "FalconH1Config"), ("falcon_mamba", "FalconMambaConfig"), ("fastspeech2_conformer", "FastSpeech2ConformerConfig"), ("flaubert", "FlaubertConfig"), @@ -481,6 +482,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("esm", "ESM"), ("falcon", "Falcon"), ("falcon3", "Falcon3"), + ("falcon_h1", "FalconH1"), ("falcon_mamba", "FalconMamba"), ("fastspeech2_conformer", "FastSpeech2Conformer"), ("flan-t5", "FLAN-T5"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a39da1ceda5..5fe74444cb2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -115,6 +115,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("ernie_m", "ErnieMModel"), ("esm", "EsmModel"), ("falcon", "FalconModel"), + ("falcon_h1", "FalconH1Model"), ("falcon_mamba", "FalconMambaModel"), ("fastspeech2_conformer", "FastSpeech2ConformerModel"), ("flaubert", "FlaubertModel"), @@ -558,6 +559,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("emu3", "Emu3ForCausalLM"), ("ernie", "ErnieForCausalLM"), ("falcon", "FalconForCausalLM"), + ("falcon_h1", "FalconH1ForCausalLM"), ("falcon_mamba", "FalconMambaForCausalLM"), ("fuyu", "FuyuForCausalLM"), ("gemma", "GemmaForCausalLM"), diff --git a/src/transformers/models/falcon_h1/__init__.py b/src/transformers/models/falcon_h1/__init__.py new file mode 100644 index 00000000000..9749c5e1e98 --- /dev/null +++ b/src/transformers/models/falcon_h1/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_falcon_h1 import * + from .modeling_falcon_h1 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/falcon_h1/configuration_falcon_h1.py b/src/transformers/models/falcon_h1/configuration_falcon_h1.py new file mode 100644 index 00000000000..2a686ecc612 --- /dev/null +++ b/src/transformers/models/falcon_h1/configuration_falcon_h1.py @@ -0,0 +1,283 @@ +# coding=utf-8 +# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FalconH1 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class FalconH1Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a + FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration + with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf). + The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU. + The checkpoints are jointly trained by IBM, Princeton, and UIUC. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 128000): + Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconH1Model`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 8192): + Max cached sequence length for the model + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mamba_d_ssm (`int`, *optional*, defaults to 1024): + The dimension of the SSM state space latents. + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used in the v2 implementation. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embeddding dimension size + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used in the v2 implementation. + mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + mamba_norm_before_gate (`bool`, *optional*, defaults to `True`): + Whether to use RMSNorm before the gate in the Mamba block + mamba_rms_norm (`bool`, *optional*, defaults to `False`): + Whether to use RMSNorm instead of LayerNorm in the Mamba block + projectors_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block + rope_theta (`float`, *optional*, defaults to 100000.0): + The theta value used for the RoPE embeddings. + rope_scaling (`float`, *optional*): + The scaling value used for the RoPE embeddings. If `None`, no scaling is applied. + lm_head_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the LM head. This is used to scale the output of the LM head. + embedding_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the embedding layer. This is used to scale the output of the embedding layer. + mlp_multipliers (`List[float]`, *optional*): + The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is + the multiplier of gate layer, the second value is the multiplier of the down_proj layer. + key_multiplier (`float`, *optional*): + The multiplier for the key layer. This is used to scale the output of the key layer. + attention_out_multiplier (`float`, *optional*): + The multiplier for the attention output layer. This is used to scale the output of the attention output + attention_in_multiplier (`float`, *optional*): + The multiplier for the attention input layer. This is used to scale the output of the attention input layer. + ssm_multipliers (`List[float]`, *optional*): + The multipliers for the SSM layers. This is used to scale the output of the SSM layers. + ssm_in_multiplier (`float`, *optional*): + The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer. + ssm_out_multiplier (`float`, *optional*): + The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer. + """ + + model_type = "falcon_h1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=8192, + attention_dropout=0.0, + mamba_d_ssm=1024, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_norm_before_gate=True, + mamba_rms_norm=False, + projectors_bias=False, + rope_theta=100000.0, + rope_scaling=None, + lm_head_multiplier=1.0, + embedding_multiplier=1.0, + mlp_multipliers=None, + key_multiplier=None, + attention_out_multiplier=None, + attention_in_multiplier=None, + ssm_multipliers=None, + ssm_in_multiplier=None, + ssm_out_multiplier=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.rope_theta = rope_theta + self.rope_scaling = None + self.rope_scaling = rope_scaling + self.projectors_bias = projectors_bias + mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size") + + self.mamba_d_ssm = mamba_d_ssm + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + self.mamba_norm_before_gate = mamba_norm_before_gate + self.mamba_rms_norm = mamba_rms_norm + + self.lm_head_multiplier = lm_head_multiplier + self.embedding_multiplier = embedding_multiplier + + if mlp_multipliers is not None: + self.mlp_multipliers = mlp_multipliers + else: + self.mlp_multipliers = [1.0, 1.0] + + if attention_out_multiplier is not None: + self.attention_out_multiplier = attention_out_multiplier + else: + self.attention_out_multiplier = 1.0 + + if attention_in_multiplier is not None: + self.attention_in_multiplier = attention_in_multiplier + else: + self.attention_in_multiplier = 1.0 + + if key_multiplier is not None: + self.key_multiplier = key_multiplier + else: + self.key_multiplier = 1.0 + + if ssm_multipliers is not None: + self.ssm_multipliers = ssm_multipliers + else: + # + self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0] + + if ssm_in_multiplier is not None: + self.ssm_in_multiplier = ssm_in_multiplier + else: + self.ssm_in_multiplier = 1.0 + + if ssm_out_multiplier is not None: + self.ssm_out_multiplier = ssm_out_multiplier + else: + self.ssm_out_multiplier = 1.0 + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return ["attention" for i in range(self.num_hidden_layers)] + + +__all__ = ["FalconH1Config"] diff --git a/src/transformers/models/falcon_h1/convert_mamba_ssm_checkpoint.py b/src/transformers/models/falcon_h1/convert_mamba_ssm_checkpoint.py new file mode 100644 index 00000000000..d3a8c4b8f5a --- /dev/null +++ b/src/transformers/models/falcon_h1/convert_mamba_ssm_checkpoint.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed.""" + +import argparse + +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer, FalconH1Config, FalconH1ForCausalLM + + +CONVERSION_MAPPING = { + "backbone": "model", + "embeddings": "embed_tokens", + "mixer.": "", + "mixer_ssm": "mamba", + "mixer_attn": "self_attn", + "mlp.": "feed_forward.", + "mlp_norm": "pre_ff_layernorm", + "ssm_proj": "mamba.in_proj", + "attn_out_proj": "o_proj", + ".norm.": ".input_layernorm.", + ".mamba.input_layernorm.": ".mamba.norm.", + ".ssm_out_proj.": ".mamba.out_proj.", + "norm_f": "final_layernorm", +} + + +def convert_falcon_h1_to_hf(input_model_path, output_path): + tokenizer = AutoTokenizer.from_pretrained(input_model_path) + + model = AutoModelForCausalLM.from_pretrained( + input_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, low_cpu_mem_usage=True + ) + + intermediate_size = int(model.config.expansion_factor * model.config.hidden_size) + + if intermediate_size % 2 != 0: + intermediate_size = intermediate_size + (intermediate_size % 2) + + new_config = FalconH1Config( + vocab_size=model.config.vocab_size, + tie_word_embeddings=model.config.tie_word_embeddings, + hidden_size=model.config.hidden_size, + intermediate_size=intermediate_size, + mamba_d_state=model.config.state_size, + num_hidden_layers=model.config.num_hidden_layers, + mamba_use_mlp=model.config.use_mlp, + rms_norm_eps=model.config.layer_norm_epsilon, + pad_token_id=model.config.pad_token_id, + eos_token_id=model.config.eos_token_id, + mamba_expand=model.config.expand, + mamba_d_conv=model.config.conv_kernel, + mamba_n_groups=model.config.n_groups, + mamba_n_heads=model.config.num_heads, + mamba_norm_before_gate=model.config.norm_before_gate, + mamba_rms_norm=model.config.rms_norm, + mamba_d_ssm=model.config.d_ssm, + attention_bias=model.config.use_bias, + projectors_bias=model.config.use_bias, + mamba_conv_bias=model.config.use_conv_bias, + hidden_act=model.config.hidden_act, + use_cache=model.config.use_cache, + mamba_chunk_size=model.config.chunk_size, + num_attention_heads=model.config.num_heads_mha, + num_key_value_heads=model.config.num_key_value_heads, + head_dim=model.config.head_dim_mha, + lm_head_multiplier=model.config.lm_head_multiplier, + embedding_multiplier=model.config.embedding_multiplier, + mlp_multipliers=model.config.mlp_multipliers, + key_multiplier=model.config.key_multiplier, + attention_out_multiplier=model.config.attention_out_multiplier, + attention_in_multiplier=model.config.attention_in_multiplier, + ssm_multipliers=model.config.ssm_multipliers, + ssm_in_multiplier=model.config.ssm_in_multiplier, + ssm_out_multiplier=model.config.ssm_out_multiplier, + rope_theta=model.config.rope_theta, + ) + + old_state_dict = model.state_dict() + new_state_dict = {} + + for old_key, old_value in old_state_dict.items(): + new_key = old_key + for conversion_key, conversion_value in CONVERSION_MAPPING.items(): + if conversion_key in old_key: + new_key = new_key.replace(conversion_key, conversion_value) + + if "mamba.input_layernorm" in new_key: + new_key = new_key.replace("mamba.input_layernorm", "mamba.norm") + + # Special processing for attention layers + if "self_attn.attn_proj" in new_key: + num_heads = new_config.num_attention_heads + num_kv_heads = new_config.num_key_value_heads + head_dim = new_config.head_dim + q_proj, k_proj, v_proj = old_value.split( + [ + num_heads * head_dim, + num_kv_heads * head_dim, + num_kv_heads * head_dim, + ], + dim=0, + ) + new_state_dict[new_key.replace("attn_proj", "q_proj")] = q_proj + new_state_dict[new_key.replace("attn_proj", "k_proj")] = k_proj + new_state_dict[new_key.replace("attn_proj", "v_proj")] = v_proj + else: + new_state_dict[new_key] = old_value + + with torch.device("meta"): + new_model = FalconH1ForCausalLM(new_config) + + del model + + new_model.load_state_dict(new_state_dict, strict=True, assign=True) + + new_model.save_pretrained(output_path) + tokenizer.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--mamba_ssm_checkpoint_directory", + type=str, + required=True, + help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", + ) + parser.add_argument( + "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." + ) + args = parser.parse_args() + + convert_falcon_h1_to_hf( + args.mamba_ssm_checkpoint_directory, + args.output_dir, + ) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py new file mode 100644 index 00000000000..5913a7f80bc --- /dev/null +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -0,0 +1,1692 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/falcon_h1/modular_falcon_h1.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_falcon_h1.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Technology Innovation Institute and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN + +from ...cache_utils import ( + Cache, + DynamicCache, # we need __iter__ and __len__ of pkv +) +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from .configuration_falcon_h1 import FalconH1Config + + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "FalconH1Config" + + +class FalconHybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__( + self, + config: FalconH1Config, + batch_size: int, + dtype: torch.dtype = torch.float16, + devices: Optional[List[str]] = None, + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.has_previous_state = False + self.conv_kernel_size = config.mamba_d_conv + + self._seen_tokens = 0 + + self.intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + self.conv_kernel_size, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + config.mamba_d_state, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + self.transformer_layers.append(i) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("FalconHybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("FalconHybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + def update_conv_state( + self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_position: torch.LongTensor, + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + if len(cache_position) > 1: + conv_state[:, :, :] = new_conv_state.to(conv_state.device) + else: + conv_state[:, :, -1] = new_conv_state[:, :, -1].to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class FalconH1RotaryEmbedding(nn.Module): + def __init__(self, config: FalconH1Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class FalconH1Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.key_multiplier = config.key_multiplier + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.key_multiplier + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class FalconH1RMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.n_groups = n_groups + self.norm_before_gate = norm_before_gate + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + + if not self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + if len(hidden_states.shape) == 3: + batch_size, seq_len, dim = hidden_states.shape + else: + batch_size, dim = hidden_states.shape + seq_len = 1 + hidden_states = hidden_states.to(torch.float32) + + hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states + hidden_states = hidden_states.view(batch_size, seq_len, dim) + + if seq_len == 1: + hidden_states = hidden_states.squeeze(1) + + if self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class FalconH1Mixer(nn.Module): + """ + FalconH1Mixer is identical to classic Mamba2 mixer classes but differs on two different things + - Users can pass custom intermediate_size through `config.mamba_d_ssm` + - The use of gated RMS normalization layer is optional + """ + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = ( + int(config.mamba_expand * self.hidden_size) if config.mamba_d_ssm is None else config.mamba_d_ssm + ) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + self.groups_time_state_size = config.mamba_n_groups * self.ssm_state_size + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.mamba_rms_norm = config.mamba_rms_norm + + if self.mamba_rms_norm: + self.norm = FalconH1RMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon, + n_groups=self.n_groups, + norm_before_gate=config.mamba_norm_before_gate, + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for FalconH1 will be used when running the model on a GPU") + + self.zxbcdt_multipliers = config.ssm_multipliers + self.ssm_in_multiplier = config.ssm_in_multiplier + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + hidden_states = hidden_states * self.ssm_in_multiplier + projected_states = self.in_proj(hidden_states) + projected_states = projected_states * self.mup_vector + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + d_mlp = (projected_states.squeeze(1).shape[-1] - d_to_remove) // 2 + + z0, x0, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=gate.view(batch_size, self.num_heads, self.head_dim) if not self.mamba_rms_norm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + + if self.mamba_rms_norm: + hidden_states = self.norm(hidden_states, gate) + + if d_mlp > 0: + hidden_states = torch.cat([F.silu(z0) * x0, hidden_states], dim=-1) + + # 4. Final linear projection + out = self.out_proj(hidden_states[:, None, ...]) + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.mamba_rms_norm else None, + rmsnorm_eps=self.norm.variance_epsilon if self.mamba_rms_norm else None, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + if attention_mask is not None: + projected_states = projected_states * attention_mask[..., None] + _, gate, hidden_states_B_C, dt = projected_states.split( + [ + 2 * d_mlp, + self.intermediate_size, + self.conv_dim, + self.num_heads, + ], + dim=-1, + ) + + if cache_params is not None: + conv_states = F.pad( + hidden_states_B_C.permute(0, 2, 1), + (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + + time_step = nn.functional.softplus(dt + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # This is a hack to make sure multi-GPU inference works with HF accelerate + # see: https://github.com/Dao-AILab/flash-attention/issues/523 for more details + with torch.cuda.device(hidden_states.device): + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + if self.mamba_rms_norm: + out = self.norm(scan_output, gate) + else: + out = scan_output * torch.nn.functional.silu(gate) + out = self.out_proj(out) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + if self.mamba_rms_norm: + scan_output = self.norm(y, gate) + else: + scan_output = y * torch.nn.functional.silu(gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconH1MLP(nn.Module): + def __init__(self, config: FalconH1Config = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + + def forward(self, x): + y = self.up_proj(x) * self.act_fn(self.gate_proj(x) * self.gate_multiplier) + y = self.down_proj(y) * self.down_multiplier + return y + + +@use_kernel_forward_from_hub("RMSNorm") +class FalconH1RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + FalconH1RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class FalconH1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.feed_forward = FalconH1MLP(config) + + head_dim = config.hidden_size // config.num_attention_heads + self.channels_attn = config.num_attention_heads * head_dim + 2 * config.num_key_value_heads * head_dim + + self.mamba = FalconH1Mixer(config=config, layer_idx=layer_idx) + + self.self_attn = FalconH1Attention(config, layer_idx) + + self.attention_in_multiplier = config.attention_in_multiplier + self.ssm_out_multiplier = config.ssm_out_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + mamba_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[FalconHybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + mamba_hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=mamba_attention_mask, + ) + mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier + + attention_hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states * self.attention_in_multiplier, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + + hidden_states = mamba_hidden_states + attention_hidden_states + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class FalconH1PreTrainedModel(PreTrainedModel): + config_class = FalconH1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FalconH1DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + param.data.fill_(1.0) + elif "bias" in name: + param.data.zero_() + else: + try: + param.data.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") + + +def compute_mup_vector(config): + """ + Computes the MuP vector based on model configuration. + + FalconH1 applies different MuP multiplier for each dimension of the hidden states. + The MuP vector is partitioned into chunks, and each chunk is multiplied with its + corresponding projected dimension. + + Args: + config: FalconH1Config object + + Returns: + torch.Tensor: The computed MuP vector + """ + # We'll need some values from the config to compute the vector dimensions + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + groups_time_state_size = config.mamba_n_groups * config.mamba_d_state + num_heads = config.mamba_n_heads + zxbcdt_multipliers = config.ssm_multipliers + + vector_shape = 2 * intermediate_size + 2 * groups_time_state_size + num_heads + mup_vector = torch.ones(1, 1, vector_shape) + + # Apply multipliers to different sections of the vector + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size : 2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size : 2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[ + :, :, 2 * intermediate_size + groups_time_state_size : 2 * intermediate_size + 2 * groups_time_state_size + ] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size :] *= zxbcdt_multipliers[4] + + return mup_vector + + +@auto_docstring +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class FalconH1Model(FalconH1PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FalconH1DecoderLayer`] + + Args: + config: FalconH1Config + """ + + def __init__(self, config: FalconH1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(FalconH1DecoderLayer(config, layer_idx=i)) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = FalconH1RotaryEmbedding(config=config) + + self.embedding_multiplier = config.embedding_multiplier + self.lm_head_multiplier = config.lm_head_multiplier + + self.gradient_checkpointing = False + # Compute the MuP vector once and register it for all layers + mup_vector = compute_mup_vector(config) + for layer in self.layers: + layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embedding_multiplier + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + mamba_attention_mask=mamba_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: FalconHybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = FalconH1Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @auto_docstring + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconH1ForCausalLM + + >>> model = FalconH1ForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.model.lm_head_multiplier + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = FalconHybridMambaAttentionDynamicCache( + self.config, + input_ids.shape[0], + self.dtype, + devices=[ + self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) + ], + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["FalconH1Model", "FalconH1ForCausalLM", "FalconH1PreTrainedModel"] diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py new file mode 100644 index 00000000000..07b9e540848 --- /dev/null +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -0,0 +1,1442 @@ +# coding=utf-8 +# Copyright 2025 Technology Innovation Institute and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch FalconH1 model.""" + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.models.mamba2.modeling_mamba2 import ( + MambaRMSNormGated, + pad_tensor_by_size, + reshape_into_chunks, + segment_sum, +) + +from ...cache_utils import Cache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_mamba_2_ssm_available, +) +from .configuration_falcon_h1 import FalconH1Config + + +if is_flash_attn_2_available(): + pass + +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "FalconH1Config" + + +class FalconHybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__( + self, + config: FalconH1Config, + batch_size: int, + dtype: torch.dtype = torch.float16, + devices: Optional[List[str]] = None, + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.has_previous_state = False + self.conv_kernel_size = config.mamba_d_conv + + self._seen_tokens = 0 + + self.intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_n_groups * config.mamba_d_state, + self.conv_kernel_size, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + config.mamba_d_state, + device=devices[i], + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + self.transformer_layers.append(i) + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append([]) + self.value_cache.append([]) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def update_conv_state( + self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_position: torch.LongTensor, + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + if len(cache_position) > 1: + conv_state[:, :, :] = new_conv_state.to(conv_state.device) + else: + conv_state[:, :, -1] = new_conv_state[:, :, -1].to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class FalconH1RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class FalconH1Attention(LlamaAttention): + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__(config, layer_idx) + self.key_multiplier = config.key_multiplier + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.key_multiplier + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class FalconH1RMSNormGated(MambaRMSNormGated): + def __init__(self, hidden_size, eps=1e-6, n_groups=1, norm_before_gate=True): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.n_groups = n_groups + self.norm_before_gate = norm_before_gate + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + + if not self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + + if len(hidden_states.shape) == 3: + batch_size, seq_len, dim = hidden_states.shape + else: + batch_size, dim = hidden_states.shape + seq_len = 1 + hidden_states = hidden_states.to(torch.float32) + + hidden_states = hidden_states.view(batch_size, seq_len, self.n_groups, int(dim // self.n_groups)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + hidden_states = self.weight.view(self.n_groups, int(dim // self.n_groups)) * hidden_states + hidden_states = hidden_states.view(batch_size, seq_len, dim) + + if seq_len == 1: + hidden_states = hidden_states.squeeze(1) + + if self.norm_before_gate and gate is not None: + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class FalconH1Mixer(nn.Module): + """ + FalconH1Mixer is identical to classic Mamba2 mixer classes but differs on two different things + - Users can pass custom intermediate_size through `config.mamba_d_ssm` + - The use of gated RMS normalization layer is optional + """ + + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = ( + int(config.mamba_expand * self.hidden_size) if config.mamba_d_ssm is None else config.mamba_d_ssm + ) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + self.groups_time_state_size = config.mamba_n_groups * self.ssm_state_size + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.mamba_rms_norm = config.mamba_rms_norm + + if self.mamba_rms_norm: + self.norm = FalconH1RMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon, + n_groups=self.n_groups, + norm_before_gate=config.mamba_norm_before_gate, + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=config.projectors_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for FalconH1 will be used when running the model on a GPU") + + self.zxbcdt_multipliers = config.ssm_multipliers + self.ssm_in_multiplier = config.ssm_in_multiplier + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + hidden_states = hidden_states * self.ssm_in_multiplier + projected_states = self.in_proj(hidden_states) + projected_states = projected_states * self.mup_vector + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + d_mlp = (projected_states.squeeze(1).shape[-1] - d_to_remove) // 2 + + z0, x0, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=gate.view(batch_size, self.num_heads, self.head_dim) if not self.mamba_rms_norm else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + + if self.mamba_rms_norm: + hidden_states = self.norm(hidden_states, gate) + + if d_mlp > 0: + hidden_states = torch.cat([F.silu(z0) * x0, hidden_states], dim=-1) + + # 4. Final linear projection + out = self.out_proj(hidden_states[:, None, ...]) + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight if self.mamba_rms_norm else None, + rmsnorm_eps=self.norm.variance_epsilon if self.mamba_rms_norm else None, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + if attention_mask is not None: + projected_states = projected_states * attention_mask[..., None] + _, gate, hidden_states_B_C, dt = projected_states.split( + [ + 2 * d_mlp, + self.intermediate_size, + self.conv_dim, + self.num_heads, + ], + dim=-1, + ) + + if cache_params is not None: + conv_states = F.pad( + hidden_states_B_C.permute(0, 2, 1), + (self.conv_kernel_size - hidden_states_B_C.shape[-2], 0), + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + + time_step = nn.functional.softplus(dt + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # This is a hack to make sure multi-GPU inference works with HF accelerate + # see: https://github.com/Dao-AILab/flash-attention/issues/523 for more details + with torch.cuda.device(hidden_states.device): + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + if self.mamba_rms_norm: + out = self.norm(scan_output, gate) + else: + out = scan_output * torch.nn.functional.silu(gate) + out = self.out_proj(out) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states[self.layer_idx].device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + if self.mamba_rms_norm: + scan_output = self.norm(y, gate) + else: + scan_output = y * torch.nn.functional.silu(gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[FalconHybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class FalconH1MLP(LlamaMLP): + def __init__(self, config: FalconH1Config = None): + super().__init__() + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + + def forward(self, x): + y = self.up_proj(x) * self.act_fn(self.gate_proj(x) * self.gate_multiplier) + y = self.down_proj(y) * self.down_multiplier + return y + + +class FalconH1RMSNorm(LlamaRMSNorm): + pass + + +class FalconH1DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: FalconH1Config, layer_idx: int): + super().__init__() + self.feed_forward = FalconH1MLP(config) + + head_dim = config.hidden_size // config.num_attention_heads + self.channels_attn = config.num_attention_heads * head_dim + 2 * config.num_key_value_heads * head_dim + + self.mamba = FalconH1Mixer(config=config, layer_idx=layer_idx) + + self.self_attn = FalconH1Attention(config, layer_idx) + + self.attention_in_multiplier = config.attention_in_multiplier + self.ssm_out_multiplier = config.ssm_out_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + mamba_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[FalconHybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + mamba_hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + cache_position=cache_position, + attention_mask=mamba_attention_mask, + ) + mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier + + attention_hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states * self.attention_in_multiplier, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + + hidden_states = mamba_hidden_states + attention_hidden_states + + # residual connection after attention + hidden_states = residual + hidden_states + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class FalconH1PreTrainedModel(PreTrainedModel): + config_class = FalconH1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FalconH1DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + param.data.fill_(1.0) + elif "bias" in name: + param.data.zero_() + else: + try: + param.data.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") + + +def compute_mup_vector(config): + """ + Computes the MuP vector based on model configuration. + + FalconH1 applies different MuP multiplier for each dimension of the hidden states. + The MuP vector is partitioned into chunks, and each chunk is multiplied with its + corresponding projected dimension. + + Args: + config: FalconH1Config object + + Returns: + torch.Tensor: The computed MuP vector + """ + # We'll need some values from the config to compute the vector dimensions + intermediate_size = ( + config.mamba_d_ssm if config.mamba_d_ssm is not None else int(config.mamba_expand * config.hidden_size) + ) + groups_time_state_size = config.mamba_n_groups * config.mamba_d_state + num_heads = config.mamba_n_heads + zxbcdt_multipliers = config.ssm_multipliers + + vector_shape = 2 * intermediate_size + 2 * groups_time_state_size + num_heads + mup_vector = torch.ones(1, 1, vector_shape) + + # Apply multipliers to different sections of the vector + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size : 2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size : 2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[ + :, :, 2 * intermediate_size + groups_time_state_size : 2 * intermediate_size + 2 * groups_time_state_size + ] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size :] *= zxbcdt_multipliers[4] + + return mup_vector + + +@auto_docstring +# Adapted from transformers.models.jamba.modeling_jamba.JambaModel +class FalconH1Model(FalconH1PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FalconH1DecoderLayer`] + + Args: + config: FalconH1Config + """ + + def __init__(self, config: FalconH1Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append(FalconH1DecoderLayer(config, layer_idx=i)) + self.layers = nn.ModuleList(decoder_layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = FalconH1RotaryEmbedding(config=config) + + self.embedding_multiplier = config.embedding_multiplier + self.lm_head_multiplier = config.lm_head_multiplier + + self.gradient_checkpointing = False + # Compute the MuP vector once and register it for all layers + mup_vector = compute_mup_vector(config) + for layer in self.layers: + layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # NOOP kwargs, for now + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embedding_multiplier + hidden_states = inputs_embeds + + if use_cache and past_key_values is None: + logger.warning_once( + "FalconH1 requires an initialized `FalconHybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + mamba_attention_mask=mamba_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: FalconHybridMambaAttentionDynamicCache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[ + :, :, -sequence_length:, : + ].to(dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class FalconH1ForCausalLM(LlamaForCausalLM): + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @auto_docstring + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconH1ForCausalLM + + >>> model = FalconH1ForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.model.lm_head_multiplier + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = FalconHybridMambaAttentionDynamicCache( + self.config, + input_ids.shape[0], + self.dtype, + devices=[ + self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers) + ], + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["FalconH1Model", "FalconH1ForCausalLM", "FalconH1PreTrainedModel"] diff --git a/tests/models/falcon_h1/__init__.py b/tests/models/falcon_h1/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py new file mode 100644 index 00000000000..b2c0af8a022 --- /dev/null +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -0,0 +1,518 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch FalconH1 model.""" + +import inspect +import unittest + +import pytest + +from transformers import FalconH1Config, is_torch_available +from transformers.testing_utils import ( + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model + from transformers.models.falcon_h1.modeling_falcon_h1 import ( + FalconHybridMambaAttentionDynamicCache, + ) + + +class FalconH1ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=64, + hidden_act="silu", + attention_dropout=0.0, + attn_layer_indices=None, + attn_rotary_emb=8, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + num_labels=3, + pad_token_id=0, + mamba_n_groups=1, + mamba_n_heads=16, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=16, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attn_layer_indices = attn_layer_indices + self.attn_rotary_emb = attn_rotary_emb + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.pad_token_id = pad_token_id + self.scope = scope + self.mamba_n_groups = mamba_n_groups + self.mamba_n_heads = mamba_n_heads + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + token_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def get_config(self): + # Fix for SDPA tests, force at least 4 layers + if self.num_hidden_layers < 4: + self.num_hidden_layers = 4 + if self.attn_layer_indices is None: + d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0] + if len(d) == 0: + raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.") + d = d[-1] # get the largest divisor + self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)] + + return FalconH1Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + attention_dropout=self.attention_dropout, + attn_layer_indices=self.attn_layer_indices, + attn_rotary_emb=self.attn_rotary_emb, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + mamba_n_groups=self.mamba_n_groups, + mamba_n_heads=self.mamba_n_heads, + mamba_d_state=self.mamba_d_state, + mamba_d_conv=self.mamba_d_conv, + mamba_expand=self.mamba_expand, + mamba_chunk_size=self.mamba_chunk_size, + ) + + def create_and_check_model( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = FalconH1Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + token_labels, + ): + model = FalconH1ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, labels=token_labels) + result = model(input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + token_labels, + ): + # config.is_decoder = True + # config.add_cross_attention = True + model = FalconH1ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + # Attention: Jamba needs the cache to be initialized to return a cache! + past_key_values = FalconHybridMambaAttentionDynamicCache( + config, + input_ids.shape[0], + model.dtype, + devices=[model.device for _ in range(model.config.num_hidden_layers)], + ) + outputs = model( + input_ids, + attention_mask=input_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + cache_position=torch.arange( + input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device + ), + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + +@require_torch +class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (FalconH1Model, FalconH1ForCausalLM) if is_torch_available() else () + test_headmasking = False + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + pipeline_model_mapping = ( + {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} + ) + + def setUp(self): + self.model_tester = FalconH1ModelTester(self) + self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_casual_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + # def test_initialization(self): + # r""" + # Overriding the test_initialization test as the A_log and D params of the FalconH1 mixer are initialized differently + # """ + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # configs_no_init = _config_zero_init(config) + # for model_class in self.all_model_classes: + # model = model_class(config=configs_no_init) + # for name, param in model.named_parameters(): + # if param.requires_grad: + # if "A_log" in name: + # A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32) + # torch.testing.assert_close(param.data, torch.log(A), rtol=1e-5, atol=1e-5) + # elif "D" in name: + # D = torch.ones(config.mamba_n_heads, dtype=torch.float32) + # torch.testing.assert_close(param.data, D, rtol=1e-5, atol=1e-5) + # else: + # self.assertIn( + # ((param.data.mean() * 1e9).round() / 1e9).item(), + # [0.0, 1.0], + # msg=f"Parameter {name} of model {model_class} seems not properly initialized", + # ) + + def test_mismatched_shapes_have_properly_initialized_weights(self): + r""" + Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the + FalconH1 mixer are initialized differently and we tested that in test_initialization + """ + self.skipTest(reason="Cumbersome and redundant for FalconH1") + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the FalconH1 model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + expected_num_attentions = self.model_tester.num_hidden_layers + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_batching_equivalence(self): + # need to disable the tril input mask + orig = self.model_tester.use_input_mask + self.model_tester.use_input_mask = False + super().test_batching_equivalence() + self.model_tester.use_input_mask = orig + + # essentially the same test in test_utils, just adjustment for rtol for this model + @pytest.mark.generate + def test_left_padding_compatibility(self): + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding + # - The model must have generative capabilities + if len(self.all_generative_model_classes) == 0: + self.skipTest(reason="No generative architecture available for this model.") + + # - The model must support padding + if not self.has_attentions: + self.skipTest(reason="This model doesn't support padding.") + + # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) + decoder_only_classes = [] + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if config.is_encoder_decoder: + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." + ) + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + input_ids = inputs_dict["input_ids"] + + # - for left padding we absolutely need to use an all ones + # attention mask, so we do not use the one in inputs_dict + attention_mask = torch.ones_like(input_ids) + + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) + + +@slow +@require_torch +@require_torch_gpu +class FalconH1ModelIntegrationTest(unittest.TestCase): + @slow + def test_llama_3_1_hard(self): + """ + An integration test for Falcon-H1. + """ + EXPECTED_TEXT = ( + "Tell me about the french revolution.\n" + "The French Revolution (1789–1799) was a period of radical social and political upheaval in France that " + "fundamentally transformed the nation and had profound effects on the rest of Europe and the world. Here are the key aspects of the revolution:\n\n" + "### **Causes**\n" + "1. **Economic Crisis**: France was in severe financial trouble due to costly wars (particularly the American Revolution), extravagant spending by the monarchy, and inefficient taxation.\n" + "2. **Social Inequality**: The rigid class system (the Ancien RΓ©gime) divided society into the privileged nobility and clergy (First Estate) and the common people (Third Estate), who bore the brunt of taxation and had few rights.\n" + "3. **Enlightenment Ideas**: Philosophers like Rousseau, Voltaire, and Montesquieu inspired ideas of liberty, equality, and popular sovereignty.\n" + "4. **Settlement of 1789**: The Estates-General convened to address the financial crisis, leading to the Third Estate's assertion of its rights and the eventual formation of the National Assembly.\n\n" + "### **Key Events**\n" + "1. **Opening of the Revolution (1789)**:\n" + "- **Storming of the Bastille**: Symbolic of the fall of royal tyranny.\n" + "- **Declaration of the Rights of Man and of the Citizen**: Proclaimed universal rights to liberty, property, and security.\n" + "- **Creation of the National Assembly**: The Third Estate declared itself the representative body of France.\n\n" + "2. **Radical Phase (1792–1794)**:\n" + "- **Reign of Terror**: Led by Maximilien Robespierre, the Committee of Public Safety enforced radical egalitarianism through the guillotine, executing thousands of perceived enemies of the revolution (monarchists, clergy, aristocrats, and counter-revolutionaries).\n" + "- **Execution of Louis XVI**: The king was guillotined in June 1793, symbolizing the end of the monarchy.\n" + ) + + model_id = "tiiuae/Falcon-H1-1.5B-Deep-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = FalconH1ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") + device = "cuda" + messages = [{"role": "user", "content": "Tell me about the french revolution."}] + input_text = tokenizer.apply_chat_template(messages, tokenize=False) + inputs = tokenizer.encode(input_text, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model.generate(inputs, max_new_tokens=512, do_sample=False) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + self.assertEqual(generated_text, EXPECTED_TEXT)