mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Merge 8afc9cadc8
into ebfbcd42da
This commit is contained in:
commit
0e5f71f052
155
docs/source/en/model_doc/modernbert-decoder.md
Normal file
155
docs/source/en/model_doc/modernbert-decoder.md
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<div style="float: right;">
|
||||||
|
<div class="flex flex-wrap space-x-1">
|
||||||
|
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||||
|
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||||
|
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
# ModernBERT Decoder
|
||||||
|
|
||||||
|
ModernBERT Decoder is the same architecture as [ModernBERT](https://huggingface.co/papers/2412.13663) but trained from scratch with a causal language modeling (CLM) objective. This allows for using the same architecture for comparing encoders and decoders. This is the decoder architecture implementation of ModernBERT, designed for autoregressive text generation tasks.
|
||||||
|
|
||||||
|
Like the encoder version, ModernBERT Decoder incorporates modern architectural improvements such as rotary positional embeddings to support sequences of up to 8192 tokens, unpadding to avoid wasting compute on padding tokens, GeGLU layers, and alternating attention patterns. However, it uses causal (unidirectional) attention to enable autoregressive generation.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Click on the ModernBERT Decoder models in the right sidebar for more examples of how to apply ModernBERT Decoder to different text generation tasks.
|
||||||
|
|
||||||
|
The example below demonstrates how to use ModernBERT Decoder for text generation with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||||
|
|
||||||
|
<hfoptions id="usage">
|
||||||
|
<hfoption id="Pipeline">
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
generator = pipeline(
|
||||||
|
task="text-generation",
|
||||||
|
model="blab-jhu/test-32m-dec",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device=0
|
||||||
|
)
|
||||||
|
generator("The future of artificial intelligence is", max_length=50, num_return_sequences=1)
|
||||||
|
|
||||||
|
# For sequence classification
|
||||||
|
classifier = pipeline(
|
||||||
|
task="text-classification",
|
||||||
|
model="blab-jhu/test-32m-dec",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device=0
|
||||||
|
)
|
||||||
|
classifier("This movie is really great!")
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="AutoModel">
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"blab-jhu/test-32m-dec",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "The future of artificial intelligence is"
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_length=50,
|
||||||
|
num_return_sequences=1,
|
||||||
|
temperature=0.7,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
print(f"Generated text: {generated_text}")
|
||||||
|
|
||||||
|
# For sequence classification
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
classifier_model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
"blab-jhu/test-32m-dec",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
num_labels=2
|
||||||
|
)
|
||||||
|
|
||||||
|
text = "This movie is really great!"
|
||||||
|
inputs = tokenizer(text, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = classifier_model(**inputs)
|
||||||
|
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
||||||
|
predicted_class = torch.argmax(predictions, dim=-1)
|
||||||
|
|
||||||
|
print(f"Predicted class: {predicted_class.item()}")
|
||||||
|
print(f"Prediction probabilities: {predictions}")
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="transformers CLI">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo "The future of artificial intelligence is" | transformers run --task text-generation --model your-username/modernbert-decoder-base --device 0
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
|
||||||
|
## ModernBertDecoderConfig
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertDecoderConfig
|
||||||
|
|
||||||
|
<frameworkcontent>
|
||||||
|
<pt>
|
||||||
|
|
||||||
|
## ModernBertDecoderModel
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertDecoderModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## ModernBertDecoderForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertDecoderForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## ModernBertDecoderForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertDecoderForSequenceClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
### Usage tips
|
||||||
|
|
||||||
|
The ModernBertDecoder model can be fine-tuned for various text generation tasks using the HuggingFace Transformers library. It supports efficient inference with features like:
|
||||||
|
|
||||||
|
- **Causal attention**: Ensures autoregressive generation by masking future tokens
|
||||||
|
- **Sliding window attention**: Alternates between local and global attention patterns for efficiency
|
||||||
|
- **Rotary positional embeddings**: Enables handling of longer sequences up to 8000 tokens
|
||||||
|
- **FlashAttention support**: Optimized attention computation for faster training and inference
|
||||||
|
|
||||||
|
</pt>
|
||||||
|
</frameworkcontent>
|
@ -204,6 +204,7 @@ if TYPE_CHECKING:
|
|||||||
from .mobilevit import *
|
from .mobilevit import *
|
||||||
from .mobilevitv2 import *
|
from .mobilevitv2 import *
|
||||||
from .modernbert import *
|
from .modernbert import *
|
||||||
|
from .modernbert_decoder import *
|
||||||
from .moonshine import *
|
from .moonshine import *
|
||||||
from .moshi import *
|
from .moshi import *
|
||||||
from .mpnet import *
|
from .mpnet import *
|
||||||
|
@ -236,6 +236,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("mobilevit", "MobileViTConfig"),
|
("mobilevit", "MobileViTConfig"),
|
||||||
("mobilevitv2", "MobileViTV2Config"),
|
("mobilevitv2", "MobileViTV2Config"),
|
||||||
("modernbert", "ModernBertConfig"),
|
("modernbert", "ModernBertConfig"),
|
||||||
|
("modernbert-decoder", "ModernBertDecoderConfig"),
|
||||||
("moonshine", "MoonshineConfig"),
|
("moonshine", "MoonshineConfig"),
|
||||||
("moshi", "MoshiConfig"),
|
("moshi", "MoshiConfig"),
|
||||||
("mpnet", "MPNetConfig"),
|
("mpnet", "MPNetConfig"),
|
||||||
@ -630,6 +631,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("mobilevit", "MobileViT"),
|
("mobilevit", "MobileViT"),
|
||||||
("mobilevitv2", "MobileViTV2"),
|
("mobilevitv2", "MobileViTV2"),
|
||||||
("modernbert", "ModernBERT"),
|
("modernbert", "ModernBERT"),
|
||||||
|
("modernbert-decoder", "ModernBertDecoder"),
|
||||||
("moonshine", "Moonshine"),
|
("moonshine", "Moonshine"),
|
||||||
("moshi", "Moshi"),
|
("moshi", "Moshi"),
|
||||||
("mpnet", "MPNet"),
|
("mpnet", "MPNet"),
|
||||||
|
@ -225,6 +225,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("mobilevit", "MobileViTModel"),
|
("mobilevit", "MobileViTModel"),
|
||||||
("mobilevitv2", "MobileViTV2Model"),
|
("mobilevitv2", "MobileViTV2Model"),
|
||||||
("modernbert", "ModernBertModel"),
|
("modernbert", "ModernBertModel"),
|
||||||
|
("modernbert-decoder", "ModernBertDecoderModel"),
|
||||||
("moonshine", "MoonshineModel"),
|
("moonshine", "MoonshineModel"),
|
||||||
("moshi", "MoshiModel"),
|
("moshi", "MoshiModel"),
|
||||||
("mpnet", "MPNetModel"),
|
("mpnet", "MPNetModel"),
|
||||||
@ -621,6 +622,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("mistral", "MistralForCausalLM"),
|
("mistral", "MistralForCausalLM"),
|
||||||
("mixtral", "MixtralForCausalLM"),
|
("mixtral", "MixtralForCausalLM"),
|
||||||
("mllama", "MllamaForCausalLM"),
|
("mllama", "MllamaForCausalLM"),
|
||||||
|
("modernbert-decoder", "ModernBertDecoderForCausalLM"),
|
||||||
("moshi", "MoshiForCausalLM"),
|
("moshi", "MoshiForCausalLM"),
|
||||||
("mpt", "MptForCausalLM"),
|
("mpt", "MptForCausalLM"),
|
||||||
("musicgen", "MusicgenForCausalLM"),
|
("musicgen", "MusicgenForCausalLM"),
|
||||||
@ -1144,6 +1146,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("mixtral", "MixtralForSequenceClassification"),
|
("mixtral", "MixtralForSequenceClassification"),
|
||||||
("mobilebert", "MobileBertForSequenceClassification"),
|
("mobilebert", "MobileBertForSequenceClassification"),
|
||||||
("modernbert", "ModernBertForSequenceClassification"),
|
("modernbert", "ModernBertForSequenceClassification"),
|
||||||
|
("modernbert-decoder", "ModernBertDecoderForSequenceClassification"),
|
||||||
("mpnet", "MPNetForSequenceClassification"),
|
("mpnet", "MPNetForSequenceClassification"),
|
||||||
("mpt", "MptForSequenceClassification"),
|
("mpt", "MptForSequenceClassification"),
|
||||||
("mra", "MraForSequenceClassification"),
|
("mra", "MraForSequenceClassification"),
|
||||||
|
27
src/transformers/models/modernbert_decoder/__init__.py
Normal file
27
src/transformers/models/modernbert_decoder/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_modernbert_decoder import *
|
||||||
|
from .modeling_modernbert_decoder import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
@ -0,0 +1,222 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_modernbert_decoder.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`ModernBertDecoderModel`]. It is used to instantiate a ModernBert
|
||||||
|
decoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the ModernBERT-base decoder.
|
||||||
|
e.g. [blab-jhu/test-32m-dec](https://huggingface.co/blab-jhu/test-32m-dec)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 50368):
|
||||||
|
Vocabulary size of the ModernBert decoder model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`ModernBertDecoderModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 1152):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 22):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
|
||||||
|
if not specified.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
|
||||||
|
The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
norm_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the normalization layers.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 50283):
|
||||||
|
Padding token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 50282):
|
||||||
|
End of stream token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 50281):
|
||||||
|
Beginning of stream token id.
|
||||||
|
cls_token_id (`int`, *optional*, defaults to 50281):
|
||||||
|
Classification token id.
|
||||||
|
sep_token_id (`int`, *optional*, defaults to 50282):
|
||||||
|
Separation token id.
|
||||||
|
global_rope_theta (`float`, *optional*, defaults to 160000.0):
|
||||||
|
The base period of the global RoPE embeddings.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
embedding_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the embeddings.
|
||||||
|
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the MLP layers.
|
||||||
|
mlp_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the MLP layers.
|
||||||
|
decoder_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the decoder layers.
|
||||||
|
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the classifier.
|
||||||
|
classifier_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the classifier.
|
||||||
|
classifier_activation (`str`, *optional*, defaults to `"gelu"`):
|
||||||
|
The activation function for the classifier.
|
||||||
|
deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
|
||||||
|
reference_compile (`bool`, *optional*):
|
||||||
|
Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
|
||||||
|
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
||||||
|
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
||||||
|
be faster in some scenarios.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
local_attention (`int`, *optional*, defaults to 128):
|
||||||
|
The sliding window size for local attention. Only used for layers that use local attention. Note that for
|
||||||
|
the decoder to match ModernBERT this is actually half of the sliding window size, so 128 => 64.
|
||||||
|
global_attn_every_n_layers (`int`, *optional*, defaults to 3):
|
||||||
|
Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
|
||||||
|
local_rope_theta (`float`, *optional*):
|
||||||
|
The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
|
||||||
|
layer_types (`list`, *optional*):
|
||||||
|
List of layer types, one for each layer. If not specified, will be automatically generated based on
|
||||||
|
`global_attn_every_n_layers`. Should contain "full_attention" or "sliding_window".
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import ModernBertDecoderModel, ModernBertDecoderConfig
|
||||||
|
|
||||||
|
>>> # Initializing a ModernBert decoder style configuration
|
||||||
|
>>> configuration = ModernBertDecoderConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the modernbert-base decoder style configuration
|
||||||
|
>>> model = ModernBertDecoderModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "modernbert-decoder"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50368,
|
||||||
|
hidden_size=768,
|
||||||
|
intermediate_size=1152,
|
||||||
|
num_hidden_layers=22,
|
||||||
|
num_attention_heads=12,
|
||||||
|
hidden_activation="gelu",
|
||||||
|
max_position_embeddings=8192,
|
||||||
|
initializer_range=0.02,
|
||||||
|
initializer_cutoff_factor=2.0,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
norm_bias=False,
|
||||||
|
pad_token_id=50283,
|
||||||
|
eos_token_id=50282,
|
||||||
|
bos_token_id=50281,
|
||||||
|
cls_token_id=50281,
|
||||||
|
sep_token_id=50282,
|
||||||
|
global_rope_theta=160000.0,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
mlp_bias=False,
|
||||||
|
mlp_dropout=0.0,
|
||||||
|
decoder_bias=True,
|
||||||
|
classifier_dropout=0.0,
|
||||||
|
classifier_bias=False,
|
||||||
|
classifier_activation="gelu",
|
||||||
|
deterministic_flash_attn=False,
|
||||||
|
reference_compile=None,
|
||||||
|
use_cache=True,
|
||||||
|
local_attention=128,
|
||||||
|
global_attn_every_n_layers=3,
|
||||||
|
local_rope_theta=None,
|
||||||
|
layer_types=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
cls_token_id=cls_token_id,
|
||||||
|
sep_token_id=sep_token_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.initializer_cutoff_factor = initializer_cutoff_factor
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
self.norm_bias = norm_bias
|
||||||
|
self.global_rope_theta = global_rope_theta
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
|
self.embedding_dropout = embedding_dropout
|
||||||
|
self.mlp_bias = mlp_bias
|
||||||
|
self.mlp_dropout = mlp_dropout
|
||||||
|
self.decoder_bias = decoder_bias
|
||||||
|
self.classifier_dropout = classifier_dropout
|
||||||
|
self.classifier_bias = classifier_bias
|
||||||
|
self.classifier_activation = classifier_activation
|
||||||
|
self.deterministic_flash_attn = deterministic_flash_attn
|
||||||
|
self.reference_compile = reference_compile
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.local_attention = local_attention
|
||||||
|
self.global_attn_every_n_layers = global_attn_every_n_layers
|
||||||
|
self.local_rope_theta = local_rope_theta
|
||||||
|
|
||||||
|
# Set up layer_types for standardized layer type detection
|
||||||
|
self.layer_types = layer_types
|
||||||
|
if self.layer_types is None:
|
||||||
|
# Create layer_types based on the alternating pattern
|
||||||
|
self.layer_types = []
|
||||||
|
for layer_id in range(num_hidden_layers):
|
||||||
|
if layer_id % global_attn_every_n_layers != 0:
|
||||||
|
self.layer_types.append("sliding_attention")
|
||||||
|
else:
|
||||||
|
self.layer_types.append("full_attention")
|
||||||
|
|
||||||
|
self.sliding_window = local_attention
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
output = super().to_dict()
|
||||||
|
output.pop("reference_compile", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ModernBertDecoderConfig"]
|
@ -0,0 +1,777 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_modernbert_decoder.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
|
from ...generation import GenerationMixin
|
||||||
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from ...models.modernbert.modeling_modernbert import (
|
||||||
|
ModernBertEmbeddings,
|
||||||
|
ModernBertMLP,
|
||||||
|
ModernBertPredictionHead,
|
||||||
|
ModernBertPreTrainedModel,
|
||||||
|
ModernBertRotaryEmbedding,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
from ...utils import auto_docstring, logging
|
||||||
|
from .configuration_modernbert_decoder import ModernBertDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: "ModernBertDecoderAttention",
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scaling: Optional[float] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""A simple eager attention implementation for ModernBERT decoder."""
|
||||||
|
if scaling is None:
|
||||||
|
scaling = module.head_dim**-0.5
|
||||||
|
|
||||||
|
# Compute attention scores
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||||
|
|
||||||
|
# Use the pre-computed attention mask
|
||||||
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderAttention(nn.Module):
|
||||||
|
"""Performs causal multi-headed self attention for ModernBERT decoder.
|
||||||
|
|
||||||
|
It supports both local attention (sliding window) and global attention patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.deterministic_flash_attn = config.deterministic_flash_attn
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
|
self.all_head_size = self.head_dim * self.num_heads
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
# NOTE: this is different than ModernBERT (separated QKV) so be sure to adapt to this
|
||||||
|
self.q_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
self.k_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
self.v_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
|
||||||
|
self.attention_type = config.layer_types[layer_idx]
|
||||||
|
if self.attention_type == "sliding_attention":
|
||||||
|
# NOTE: to match ModernBERT, we need to divide by 2 and add one for inclusive
|
||||||
|
self.local_attention = (config.local_attention // 2 + 1, config.local_attention // 2 + 1)
|
||||||
|
else:
|
||||||
|
self.local_attention = (-1, -1)
|
||||||
|
|
||||||
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||||
|
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
|
|
||||||
|
query = self.q_proj(hidden_states)
|
||||||
|
key = self.k_proj(hidden_states)
|
||||||
|
value = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Reshape to [batch_size, seq_len, num_heads, head_dim]
|
||||||
|
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
# Transpose to [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
# Apply rotary embeddings (passed from model level)
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
attention_interface = eager_attention_forward
|
||||||
|
if self.config._attn_implementation != "eager":
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
|
||||||
|
# Pass sliding window parameter for sliding attention layers
|
||||||
|
sliding_window_param = self.local_attention[0] if self.local_attention[0] != -1 else None
|
||||||
|
|
||||||
|
attn_outputs = attention_interface(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask,
|
||||||
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
|
sliding_window=sliding_window_param,
|
||||||
|
is_causal=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_outputs[0]
|
||||||
|
attn_weights = attn_outputs[1] if output_attentions and len(attn_outputs) > 1 else None
|
||||||
|
|
||||||
|
# Reshape to [batch_size, seq_len, hidden_size] - this handles both eager and FA2 outputs
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
|
|
||||||
|
# Apply output projection
|
||||||
|
hidden_states = self.out_drop(self.Wo(attn_output))
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weights,)
|
||||||
|
if past_key_value is not None:
|
||||||
|
outputs += (past_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderLayer(GradientCheckpointingLayer):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.attention_type = config.layer_types[layer_idx]
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self.attn_norm = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
|
||||||
|
self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.mlp = ModernBertMLP(config)
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
|
def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.mlp(self.mlp_norm(hidden_states))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.attn_norm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_outputs = self.attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
cache_position=cache_position,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = attn_outputs[0]
|
||||||
|
|
||||||
|
# Add residual connection
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# MLP
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.mlp_norm(hidden_states)
|
||||||
|
mlp_output = self.compiled_mlp(hidden_states) if self.config.reference_compile else self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + mlp_output
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
if len(attn_outputs) > 1:
|
||||||
|
outputs += attn_outputs[1:]
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
|
||||||
|
config_class = ModernBertDecoderConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
|
_no_split_modules = ["ModernBertDecoderLayer"]
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = False
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = False
|
||||||
|
_supports_attention_backend = True
|
||||||
|
|
||||||
|
def _init_weights(self, module: nn.Module):
|
||||||
|
cutoff_factor = self.config.initializer_cutoff_factor
|
||||||
|
if cutoff_factor is None:
|
||||||
|
cutoff_factor = 3
|
||||||
|
|
||||||
|
def init_weight(module: nn.Module, std: float):
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
module.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=std,
|
||||||
|
a=-cutoff_factor * std,
|
||||||
|
b=cutoff_factor * std,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
stds = {
|
||||||
|
"in": self.config.initializer_range,
|
||||||
|
"out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
|
||||||
|
"embedding": self.config.initializer_range,
|
||||||
|
"final_out": self.config.hidden_size**-0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(module, ModernBertEmbeddings):
|
||||||
|
init_weight(module.tok_embeddings, stds["embedding"])
|
||||||
|
elif isinstance(module, ModernBertMLP):
|
||||||
|
init_weight(module.Wi, stds["in"])
|
||||||
|
init_weight(module.Wo, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderAttention):
|
||||||
|
init_weight(module.q_proj, stds["in"])
|
||||||
|
init_weight(module.k_proj, stds["in"])
|
||||||
|
init_weight(module.v_proj, stds["in"])
|
||||||
|
init_weight(module.Wo, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertPredictionHead):
|
||||||
|
init_weight(module.dense, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderForSequenceClassification):
|
||||||
|
init_weight(module.classifier, stds["final_out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderForCausalLM):
|
||||||
|
init_weight(module.decoder, stds["out"])
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
|
model_embeds = super().resize_token_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.config.reference_compile in {True, None}:
|
||||||
|
if self.config.reference_compile:
|
||||||
|
logger.warning_once(
|
||||||
|
"Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
|
||||||
|
)
|
||||||
|
self.config.reference_compile = False
|
||||||
|
|
||||||
|
return model_embeds
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.embeddings = ModernBertEmbeddings(config)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.global_rotary_emb = ModernBertRotaryEmbedding(
|
||||||
|
config=config, dim=config.hidden_size // config.num_attention_heads, base=config.global_rope_theta
|
||||||
|
)
|
||||||
|
if config.local_rope_theta is not None:
|
||||||
|
self.local_rotary_emb = ModernBertRotaryEmbedding(
|
||||||
|
config=config, dim=config.hidden_size // config.num_attention_heads, base=config.local_rope_theta
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.local_rotary_emb = self.global_rotary_emb
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (input_ids is None) == (inputs_embeds is None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||||
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
|
else:
|
||||||
|
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||||
|
|
||||||
|
# Handle past_key_values and cache setup
|
||||||
|
if use_cache and past_key_values is None and not self.training:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + seq_length,
|
||||||
|
device=input_ids.device if input_ids is not None else inputs_embeds.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
||||||
|
|
||||||
|
# Calculate embeddings
|
||||||
|
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
|
# It may already have been prepared by e.g. `generate`
|
||||||
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
|
# Prepare mask arguments
|
||||||
|
mask_kwargs = {
|
||||||
|
"config": self.config,
|
||||||
|
"input_embeds": hidden_states,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
}
|
||||||
|
|
||||||
|
causal_mask_mapping = {
|
||||||
|
"full_attention": create_causal_mask(**mask_kwargs),
|
||||||
|
}
|
||||||
|
|
||||||
|
if any(layer_type == "sliding_attention" for layer_type in self.config.layer_types):
|
||||||
|
# NOTE: sliding window numbers matches ModernBERT but is only half of it
|
||||||
|
# +1 is because it is inclusive of that number
|
||||||
|
if hasattr(self.config, "local_attention") and self.config.local_attention is not None:
|
||||||
|
self.config.sliding_window = self.config.local_attention // 2 + 1
|
||||||
|
|
||||||
|
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||||
|
else:
|
||||||
|
causal_mask_mapping["sliding_attention"] = causal_mask_mapping["full_attention"]
|
||||||
|
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
next_decoder_cache = past_key_values if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
# Get the appropriate rotary embedding for this layer
|
||||||
|
if decoder_layer.attention_type == "sliding_attention":
|
||||||
|
rotary_emb = self.local_rotary_emb
|
||||||
|
else:
|
||||||
|
rotary_emb = self.global_rotary_emb
|
||||||
|
cos, sin = rotary_emb(hidden_states, position_ids)
|
||||||
|
position_embeddings = (cos, sin)
|
||||||
|
|
||||||
|
# Use the appropriate mask for this layer's attention type
|
||||||
|
layer_attention_mask = causal_mask_mapping[decoder_layer.attention_type]
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=layer_attention_mask,
|
||||||
|
past_key_value=next_decoder_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.final_norm(hidden_states)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
|
||||||
|
custom_intro="""
|
||||||
|
The ModernBert Decoder Model with a language modeling head on top for causal language modeling (CLM).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
|
||||||
|
_tied_weights_keys = ["decoder.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.model = ModernBertDecoderModel(config)
|
||||||
|
self.lm_head = ModernBertPredictionHead(config)
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.decoder = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
|
def compiled_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.decoder(self.lm_head(hidden_states))
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~modeling_outputs.CausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`: A
|
||||||
|
[`~modeling_outputs.CausalLMOutputWithPast`] or a tuple of `torch.FloatTensor` (if `return_dict=False`)
|
||||||
|
comprising various elements depending on the configuration and inputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, ModernBertDecoderForCausalLM
|
||||||
|
|
||||||
|
>>> model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
>>> prompt = "The capital of France is"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=1)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The capital of France is Paris"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = (
|
||||||
|
self.compiled_head(hidden_states)
|
||||||
|
if self.config.reference_compile
|
||||||
|
else self.decoder(self.lm_head(hidden_states))
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past_key_values, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past_key_values:
|
||||||
|
reordered_past += (
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||||
|
)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
|
||||||
|
custom_intro="""
|
||||||
|
The ModernBert Decoder Model with a sequence classification head on top (linear layer).
|
||||||
|
|
||||||
|
[`ModernBertDecoderForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||||
|
(e.g. GPT-1, GPT-2) do.
|
||||||
|
|
||||||
|
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||||
|
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||||
|
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||||
|
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||||
|
each row of the batch).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.model = ModernBertDecoderModel(config)
|
||||||
|
|
||||||
|
self.head = ModernBertPredictionHead(config)
|
||||||
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
|
||||||
|
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
@auto_docstring(checkpoint="blab-jhu/test-32m-dec")
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple, SequenceClassifierOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
transformer_outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
hidden_states = self.drop(self.head(hidden_states))
|
||||||
|
logits = self.classifier(hidden_states)
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, sequence_length = input_ids.shape[:2]
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||||
|
|
||||||
|
if self.config.pad_token_id is None and batch_size != 1:
|
||||||
|
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||||
|
if self.config.pad_token_id is None:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
elif input_ids is not None:
|
||||||
|
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||||
|
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||||
|
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||||
|
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||||
|
else:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
logger.warning_once(
|
||||||
|
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||||
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
if self.config.problem_type is None:
|
||||||
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
if self.num_labels == 1:
|
||||||
|
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||||
|
else:
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (pooled_logits,) + transformer_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=pooled_logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModernBertDecoderModel",
|
||||||
|
"ModernBertDecoderPreTrainedModel",
|
||||||
|
"ModernBertDecoderForCausalLM",
|
||||||
|
"ModernBertDecoderForSequenceClassification",
|
||||||
|
]
|
@ -0,0 +1,968 @@
|
|||||||
|
# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from ...cache_utils import Cache, DynamicCache
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...generation import GenerationMixin
|
||||||
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||||
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from ...models.modernbert.modeling_modernbert import (
|
||||||
|
ModernBertEmbeddings,
|
||||||
|
ModernBertMLP,
|
||||||
|
ModernBertPredictionHead,
|
||||||
|
ModernBertPreTrainedModel,
|
||||||
|
ModernBertRotaryEmbedding,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
from ...utils import auto_docstring, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`ModernBertDecoderModel`]. It is used to instantiate a ModernBert
|
||||||
|
decoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the ModernBERT-base decoder.
|
||||||
|
e.g. [blab-jhu/test-32m-dec](https://huggingface.co/blab-jhu/test-32m-dec)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 50368):
|
||||||
|
Vocabulary size of the ModernBert decoder model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`ModernBertDecoderModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 1152):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 22):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
|
||||||
|
if not specified.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
|
||||||
|
The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
norm_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the normalization layers.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 50283):
|
||||||
|
Padding token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 50282):
|
||||||
|
End of stream token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 50281):
|
||||||
|
Beginning of stream token id.
|
||||||
|
cls_token_id (`int`, *optional*, defaults to 50281):
|
||||||
|
Classification token id.
|
||||||
|
sep_token_id (`int`, *optional*, defaults to 50282):
|
||||||
|
Separation token id.
|
||||||
|
global_rope_theta (`float`, *optional*, defaults to 160000.0):
|
||||||
|
The base period of the global RoPE embeddings.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
embedding_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the embeddings.
|
||||||
|
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the MLP layers.
|
||||||
|
mlp_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the MLP layers.
|
||||||
|
decoder_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the decoder layers.
|
||||||
|
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the classifier.
|
||||||
|
classifier_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use bias in the classifier.
|
||||||
|
classifier_activation (`str`, *optional*, defaults to `"gelu"`):
|
||||||
|
The activation function for the classifier.
|
||||||
|
deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
|
||||||
|
reference_compile (`bool`, *optional*):
|
||||||
|
Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
|
||||||
|
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
|
||||||
|
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
|
||||||
|
be faster in some scenarios.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
local_attention (`int`, *optional*, defaults to 128):
|
||||||
|
The sliding window size for local attention. Only used for layers that use local attention. Note that for
|
||||||
|
the decoder to match ModernBERT this is actually half of the sliding window size, so 128 => 64.
|
||||||
|
global_attn_every_n_layers (`int`, *optional*, defaults to 3):
|
||||||
|
Every `global_attn_every_n_layers` layers will use global attention instead of local attention.
|
||||||
|
local_rope_theta (`float`, *optional*):
|
||||||
|
The base period of the local RoPE embeddings. If not specified, uses the same value as `global_rope_theta`.
|
||||||
|
layer_types (`list`, *optional*):
|
||||||
|
List of layer types, one for each layer. If not specified, will be automatically generated based on
|
||||||
|
`global_attn_every_n_layers`. Should contain "full_attention" or "sliding_window".
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import ModernBertDecoderModel, ModernBertDecoderConfig
|
||||||
|
|
||||||
|
>>> # Initializing a ModernBert decoder style configuration
|
||||||
|
>>> configuration = ModernBertDecoderConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the modernbert-base decoder style configuration
|
||||||
|
>>> model = ModernBertDecoderModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "modernbert-decoder"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50368,
|
||||||
|
hidden_size=768,
|
||||||
|
intermediate_size=1152,
|
||||||
|
num_hidden_layers=22,
|
||||||
|
num_attention_heads=12,
|
||||||
|
hidden_activation="gelu",
|
||||||
|
max_position_embeddings=8192,
|
||||||
|
initializer_range=0.02,
|
||||||
|
initializer_cutoff_factor=2.0,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
norm_bias=False,
|
||||||
|
pad_token_id=50283,
|
||||||
|
eos_token_id=50282,
|
||||||
|
bos_token_id=50281,
|
||||||
|
cls_token_id=50281,
|
||||||
|
sep_token_id=50282,
|
||||||
|
global_rope_theta=160000.0,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
mlp_bias=False,
|
||||||
|
mlp_dropout=0.0,
|
||||||
|
decoder_bias=True,
|
||||||
|
classifier_dropout=0.0,
|
||||||
|
classifier_bias=False,
|
||||||
|
classifier_activation="gelu",
|
||||||
|
deterministic_flash_attn=False,
|
||||||
|
reference_compile=None,
|
||||||
|
use_cache=True,
|
||||||
|
local_attention=128,
|
||||||
|
global_attn_every_n_layers=3,
|
||||||
|
local_rope_theta=None,
|
||||||
|
layer_types=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
cls_token_id=cls_token_id,
|
||||||
|
sep_token_id=sep_token_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.initializer_cutoff_factor = initializer_cutoff_factor
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
self.norm_bias = norm_bias
|
||||||
|
self.global_rope_theta = global_rope_theta
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
|
self.embedding_dropout = embedding_dropout
|
||||||
|
self.mlp_bias = mlp_bias
|
||||||
|
self.mlp_dropout = mlp_dropout
|
||||||
|
self.decoder_bias = decoder_bias
|
||||||
|
self.classifier_dropout = classifier_dropout
|
||||||
|
self.classifier_bias = classifier_bias
|
||||||
|
self.classifier_activation = classifier_activation
|
||||||
|
self.deterministic_flash_attn = deterministic_flash_attn
|
||||||
|
self.reference_compile = reference_compile
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.local_attention = local_attention
|
||||||
|
self.global_attn_every_n_layers = global_attn_every_n_layers
|
||||||
|
self.local_rope_theta = local_rope_theta
|
||||||
|
|
||||||
|
# Set up layer_types for standardized layer type detection
|
||||||
|
self.layer_types = layer_types
|
||||||
|
if self.layer_types is None:
|
||||||
|
# Create layer_types based on the alternating pattern
|
||||||
|
self.layer_types = []
|
||||||
|
for layer_id in range(num_hidden_layers):
|
||||||
|
if layer_id % global_attn_every_n_layers != 0:
|
||||||
|
self.layer_types.append("sliding_attention")
|
||||||
|
else:
|
||||||
|
self.layer_types.append("full_attention")
|
||||||
|
|
||||||
|
self.sliding_window = local_attention
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
output = super().to_dict()
|
||||||
|
output.pop("reference_compile", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: "ModernBertDecoderAttention",
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scaling: Optional[float] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""A simple eager attention implementation for ModernBERT decoder."""
|
||||||
|
if scaling is None:
|
||||||
|
scaling = module.head_dim**-0.5
|
||||||
|
|
||||||
|
# Compute attention scores
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||||
|
|
||||||
|
# Use the pre-computed attention mask
|
||||||
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderAttention(nn.Module):
|
||||||
|
"""Performs causal multi-headed self attention for ModernBERT decoder.
|
||||||
|
|
||||||
|
It supports both local attention (sliding window) and global attention patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.deterministic_flash_attn = config.deterministic_flash_attn
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
|
self.all_head_size = self.head_dim * self.num_heads
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
# NOTE: this is different than ModernBERT (separated QKV) so be sure to adapt to this
|
||||||
|
self.q_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
self.k_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
self.v_proj = nn.Linear(self.config.hidden_size, self.all_head_size, bias=self.config.attention_bias)
|
||||||
|
|
||||||
|
self.attention_type = config.layer_types[layer_idx]
|
||||||
|
if self.attention_type == "sliding_attention":
|
||||||
|
# NOTE: to match ModernBERT, we need to divide by 2 and add one for inclusive
|
||||||
|
self.local_attention = (config.local_attention // 2 + 1, config.local_attention // 2 + 1)
|
||||||
|
else:
|
||||||
|
self.local_attention = (-1, -1)
|
||||||
|
|
||||||
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
|
||||||
|
self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
|
|
||||||
|
query = self.q_proj(hidden_states)
|
||||||
|
key = self.k_proj(hidden_states)
|
||||||
|
value = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Reshape to [batch_size, seq_len, num_heads, head_dim]
|
||||||
|
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
# Transpose to [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
# Apply rotary embeddings (passed from model level)
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
attention_interface = eager_attention_forward
|
||||||
|
if self.config._attn_implementation != "eager":
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
|
||||||
|
# Pass sliding window parameter for sliding attention layers
|
||||||
|
sliding_window_param = self.local_attention[0] if self.local_attention[0] != -1 else None
|
||||||
|
|
||||||
|
attn_outputs = attention_interface(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask,
|
||||||
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
|
sliding_window=sliding_window_param,
|
||||||
|
is_causal=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_outputs[0]
|
||||||
|
attn_weights = attn_outputs[1] if output_attentions and len(attn_outputs) > 1 else None
|
||||||
|
|
||||||
|
# Reshape to [batch_size, seq_len, hidden_size] - this handles both eager and FA2 outputs
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
|
|
||||||
|
# Apply output projection
|
||||||
|
hidden_states = self.out_drop(self.Wo(attn_output))
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weights,)
|
||||||
|
if past_key_value is not None:
|
||||||
|
outputs += (past_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderLayer(GradientCheckpointingLayer):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.attention_type = config.layer_types[layer_idx]
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self.attn_norm = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.attn = ModernBertDecoderAttention(config=config, layer_idx=layer_idx)
|
||||||
|
self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.mlp = ModernBertMLP(config)
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
|
def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.mlp(self.mlp_norm(hidden_states))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.attn_norm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_outputs = self.attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
cache_position=cache_position,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = attn_outputs[0]
|
||||||
|
|
||||||
|
# Add residual connection
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# MLP
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.mlp_norm(hidden_states)
|
||||||
|
mlp_output = self.compiled_mlp(hidden_states) if self.config.reference_compile else self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + mlp_output
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
if len(attn_outputs) > 1:
|
||||||
|
outputs += attn_outputs[1:]
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
|
||||||
|
config_class = ModernBertDecoderConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
|
_no_split_modules = ["ModernBertDecoderLayer"]
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = False
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = False
|
||||||
|
_supports_attention_backend = True
|
||||||
|
|
||||||
|
def _init_weights(self, module: nn.Module):
|
||||||
|
cutoff_factor = self.config.initializer_cutoff_factor
|
||||||
|
if cutoff_factor is None:
|
||||||
|
cutoff_factor = 3
|
||||||
|
|
||||||
|
def init_weight(module: nn.Module, std: float):
|
||||||
|
nn.init.trunc_normal_(
|
||||||
|
module.weight,
|
||||||
|
mean=0.0,
|
||||||
|
std=std,
|
||||||
|
a=-cutoff_factor * std,
|
||||||
|
b=cutoff_factor * std,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
stds = {
|
||||||
|
"in": self.config.initializer_range,
|
||||||
|
"out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
|
||||||
|
"embedding": self.config.initializer_range,
|
||||||
|
"final_out": self.config.hidden_size**-0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(module, ModernBertEmbeddings):
|
||||||
|
init_weight(module.tok_embeddings, stds["embedding"])
|
||||||
|
elif isinstance(module, ModernBertMLP):
|
||||||
|
init_weight(module.Wi, stds["in"])
|
||||||
|
init_weight(module.Wo, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderAttention):
|
||||||
|
init_weight(module.q_proj, stds["in"])
|
||||||
|
init_weight(module.k_proj, stds["in"])
|
||||||
|
init_weight(module.v_proj, stds["in"])
|
||||||
|
init_weight(module.Wo, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertPredictionHead):
|
||||||
|
init_weight(module.dense, stds["out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderForSequenceClassification):
|
||||||
|
init_weight(module.classifier, stds["final_out"])
|
||||||
|
elif isinstance(module, ModernBertDecoderForCausalLM):
|
||||||
|
init_weight(module.decoder, stds["out"])
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
|
model_embeds = super().resize_token_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.config.reference_compile in {True, None}:
|
||||||
|
if self.config.reference_compile:
|
||||||
|
logger.warning_once(
|
||||||
|
"Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
|
||||||
|
)
|
||||||
|
self.config.reference_compile = False
|
||||||
|
|
||||||
|
return model_embeds
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.embeddings = ModernBertEmbeddings(config)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[ModernBertDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.global_rotary_emb = ModernBertRotaryEmbedding(
|
||||||
|
config=config, dim=config.hidden_size // config.num_attention_heads, base=config.global_rope_theta
|
||||||
|
)
|
||||||
|
if config.local_rope_theta is not None:
|
||||||
|
self.local_rotary_emb = ModernBertRotaryEmbedding(
|
||||||
|
config=config, dim=config.hidden_size // config.num_attention_heads, base=config.local_rope_theta
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.local_rotary_emb = self.global_rotary_emb
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (input_ids is None) == (inputs_embeds is None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||||
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
|
else:
|
||||||
|
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||||
|
|
||||||
|
# Handle past_key_values and cache setup
|
||||||
|
if use_cache and past_key_values is None and not self.training:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + seq_length,
|
||||||
|
device=input_ids.device if input_ids is not None else inputs_embeds.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
|
||||||
|
|
||||||
|
# Calculate embeddings
|
||||||
|
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
|
# It may already have been prepared by e.g. `generate`
|
||||||
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
|
# Prepare mask arguments
|
||||||
|
mask_kwargs = {
|
||||||
|
"config": self.config,
|
||||||
|
"input_embeds": hidden_states,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
}
|
||||||
|
|
||||||
|
causal_mask_mapping = {
|
||||||
|
"full_attention": create_causal_mask(**mask_kwargs),
|
||||||
|
}
|
||||||
|
|
||||||
|
if any(layer_type == "sliding_attention" for layer_type in self.config.layer_types):
|
||||||
|
# NOTE: sliding window numbers matches ModernBERT but is only half of it
|
||||||
|
# +1 is because it is inclusive of that number
|
||||||
|
if hasattr(self.config, "local_attention") and self.config.local_attention is not None:
|
||||||
|
self.config.sliding_window = self.config.local_attention // 2 + 1
|
||||||
|
|
||||||
|
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||||
|
else:
|
||||||
|
causal_mask_mapping["sliding_attention"] = causal_mask_mapping["full_attention"]
|
||||||
|
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
next_decoder_cache = past_key_values if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
# Get the appropriate rotary embedding for this layer
|
||||||
|
if decoder_layer.attention_type == "sliding_attention":
|
||||||
|
rotary_emb = self.local_rotary_emb
|
||||||
|
else:
|
||||||
|
rotary_emb = self.global_rotary_emb
|
||||||
|
cos, sin = rotary_emb(hidden_states, position_ids)
|
||||||
|
position_embeddings = (cos, sin)
|
||||||
|
|
||||||
|
# Use the appropriate mask for this layer's attention type
|
||||||
|
layer_attention_mask = causal_mask_mapping[decoder_layer.attention_type]
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
attention_mask=layer_attention_mask,
|
||||||
|
past_key_value=next_decoder_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.final_norm(hidden_states)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
|
||||||
|
custom_intro="""
|
||||||
|
The ModernBert Decoder Model with a language modeling head on top for causal language modeling (CLM).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
|
||||||
|
_tied_weights_keys = ["decoder.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.model = ModernBertDecoderModel(config)
|
||||||
|
self.lm_head = ModernBertPredictionHead(config)
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.decoder = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True)
|
||||||
|
def compiled_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.decoder(self.lm_head(hidden_states))
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~modeling_outputs.CausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`: A
|
||||||
|
[`~modeling_outputs.CausalLMOutputWithPast`] or a tuple of `torch.FloatTensor` (if `return_dict=False`)
|
||||||
|
comprising various elements depending on the configuration and inputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, ModernBertDecoderForCausalLM
|
||||||
|
|
||||||
|
>>> model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
>>> prompt = "The capital of France is"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=1)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The capital of France is Paris"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = (
|
||||||
|
self.compiled_head(hidden_states)
|
||||||
|
if self.config.reference_compile
|
||||||
|
else self.decoder(self.lm_head(hidden_states))
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past_key_values, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past_key_values:
|
||||||
|
reordered_past += (
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||||
|
)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
|
||||||
|
custom_intro="""
|
||||||
|
The ModernBert Decoder Model with a sequence classification head on top (linear layer).
|
||||||
|
|
||||||
|
[`ModernBertDecoderForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||||
|
(e.g. GPT-1, GPT-2) do.
|
||||||
|
|
||||||
|
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||||
|
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||||
|
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||||
|
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||||
|
each row of the batch).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertDecoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.model = ModernBertDecoderModel(config)
|
||||||
|
|
||||||
|
self.head = ModernBertPredictionHead(config)
|
||||||
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=config.classifier_bias)
|
||||||
|
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embeddings.tok_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embeddings.tok_embeddings = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
@auto_docstring(checkpoint="blab-jhu/test-32m-dec")
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[tuple, SequenceClassifierOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
transformer_outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
hidden_states = self.drop(self.head(hidden_states))
|
||||||
|
logits = self.classifier(hidden_states)
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, sequence_length = input_ids.shape[:2]
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||||
|
|
||||||
|
if self.config.pad_token_id is None and batch_size != 1:
|
||||||
|
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||||
|
if self.config.pad_token_id is None:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
elif input_ids is not None:
|
||||||
|
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||||
|
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||||
|
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||||
|
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||||
|
else:
|
||||||
|
last_non_pad_token = -1
|
||||||
|
logger.warning_once(
|
||||||
|
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||||
|
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
if self.config.problem_type is None:
|
||||||
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
if self.num_labels == 1:
|
||||||
|
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||||
|
else:
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (pooled_logits,) + transformer_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=pooled_logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModernBertDecoderConfig",
|
||||||
|
"ModernBertDecoderModel",
|
||||||
|
"ModernBertDecoderPreTrainedModel",
|
||||||
|
"ModernBertDecoderForCausalLM",
|
||||||
|
"ModernBertDecoderForSequenceClassification",
|
||||||
|
]
|
0
tests/models/modernbert_decoder/__init__.py
Normal file
0
tests/models/modernbert_decoder/__init__.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, ModernBertDecoderConfig, is_torch_available
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||||
|
from ...test_modeling_common import _config_zero_init
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
ModernBertDecoderForCausalLM,
|
||||||
|
ModernBertDecoderForSequenceClassification,
|
||||||
|
ModernBertDecoderModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertDecoderModelTester(CausalLMModelTester):
|
||||||
|
config_class = ModernBertDecoderConfig
|
||||||
|
if is_torch_available():
|
||||||
|
base_model_class = ModernBertDecoderModel
|
||||||
|
causal_lm_class = ModernBertDecoderForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class ModernBertDecoderModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(ModernBertDecoderModel, ModernBertDecoderForCausalLM, ModernBertDecoderForSequenceClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": ModernBertDecoderModel,
|
||||||
|
"text-generation": ModernBertDecoderForCausalLM,
|
||||||
|
"text-classification": ModernBertDecoderForSequenceClassification,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
model_tester_class = ModernBertDecoderModelTester
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
# The classifier.weight from ModernBertDecoderForSequenceClassification
|
||||||
|
# is initialized without `initializer_range`, so it's not set to ~0 via the _config_zero_init
|
||||||
|
if param.requires_grad and not (
|
||||||
|
name == "classifier.weight" and model_class in [ModernBertDecoderForSequenceClassification]
|
||||||
|
):
|
||||||
|
data = torch.flatten(param.data)
|
||||||
|
n_elements = torch.numel(data)
|
||||||
|
# skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
|
||||||
|
# https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
|
||||||
|
n_elements_to_skip_on_each_side = int(n_elements * 0.025)
|
||||||
|
data_to_check = torch.sort(data).values
|
||||||
|
if n_elements_to_skip_on_each_side > 0:
|
||||||
|
data_to_check = data_to_check[n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side]
|
||||||
|
self.assertIn(
|
||||||
|
((data_to_check.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
class ModernBertDecoderIntegrationTest(unittest.TestCase):
|
||||||
|
def test_inference_causal_lm(self):
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
model = ModernBertDecoderForCausalLM.from_pretrained(
|
||||||
|
"blab-jhu/test-32m-dec", reference_compile=False, attn_implementation="sdpa"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
inputs = tokenizer("Paris is the capital of", return_tensors="pt")
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs)[0]
|
||||||
|
expected_shape = torch.Size((1, 6, model.config.vocab_size))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
|
# compare the actual values for a slice.
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[[-8.0183, -7.1578, -0.4453], [-6.2909, -6.1557, 4.9063], [-6.7689, -5.8068, 6.1078]]]
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
def test_inference_no_head(self):
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
model = ModernBertDecoderModel.from_pretrained(
|
||||||
|
"blab-jhu/test-32m-dec", reference_compile=False, attn_implementation="sdpa"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
inputs = tokenizer("Paris is the capital of", return_tensors="pt")
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs)[0]
|
||||||
|
expected_shape = torch.Size((1, 6, model.config.hidden_size))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
|
# compare the actual values for a slice.
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[[0.3151, -0.6417, -0.7027], [-0.7834, -1.5810, 0.4576], [1.0614, -0.7268, -0.0871]]]
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
def test_generation(self):
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
inputs = tokenizer("The weather today is", return_tensors="pt")
|
||||||
|
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
|
||||||
|
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Check that we got some reasonable output
|
||||||
|
self.assertEqual(len(output_text), 1)
|
||||||
|
self.assertTrue(len(output_text[0]) > len("The weather today is"))
|
||||||
|
|
||||||
|
def test_sliding_window_long_context(self):
|
||||||
|
"""
|
||||||
|
Test that ModernBertDecoder works with sliding window attention for longer sequences.
|
||||||
|
"""
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
# Create a longer input to test sliding window attention
|
||||||
|
long_input = "This is a test. " * 50 # Repeat to make it longer
|
||||||
|
inputs = tokenizer(long_input, return_tensors="pt", truncation=True, max_length=512)
|
||||||
|
|
||||||
|
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
|
||||||
|
# Check that generation worked with longer context
|
||||||
|
self.assertEqual(outputs.shape[0], 1)
|
||||||
|
self.assertGreater(outputs.shape[1], inputs["input_ids"].shape[1])
|
||||||
|
|
||||||
|
def test_sequence_classification(self):
|
||||||
|
"""
|
||||||
|
Test that ModernBertDecoderForSequenceClassification works correctly.
|
||||||
|
"""
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
|
||||||
|
model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", num_labels=2)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
|
||||||
|
|
||||||
|
# Test with sample input
|
||||||
|
inputs = tokenizer("This is a positive example.", return_tensors="pt")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
# Check output shape
|
||||||
|
expected_shape = (1, 2) # batch_size=1, num_labels=2
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
|
# Test with labels
|
||||||
|
labels = torch.tensor([1])
|
||||||
|
outputs_with_loss = model(**inputs, labels=labels)
|
||||||
|
|
||||||
|
# Check that loss is computed
|
||||||
|
self.assertIsNotNone(outputs_with_loss.loss)
|
||||||
|
self.assertTrue(isinstance(outputs_with_loss.loss.item(), float))
|
@ -3309,6 +3309,8 @@ class ModelTesterMixin:
|
|||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
"TimmWrapperForImageClassification",
|
"TimmWrapperForImageClassification",
|
||||||
"ModernBertForQuestionAnswering",
|
"ModernBertForQuestionAnswering",
|
||||||
|
"ModernBertDecoderForSequenceClassification",
|
||||||
|
"ModernBertDecoderForCausalLM",
|
||||||
]
|
]
|
||||||
special_param_names = [
|
special_param_names = [
|
||||||
r"^bit\.",
|
r"^bit\.",
|
||||||
|
@ -276,6 +276,17 @@ SPECIAL_CASES_TO_ALLOW = {
|
|||||||
"attention_chunk_size",
|
"attention_chunk_size",
|
||||||
],
|
],
|
||||||
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
||||||
|
"ModernBertDecoderConfig": [
|
||||||
|
"embedding_dropout",
|
||||||
|
"hidden_activation",
|
||||||
|
"initializer_cutoff_factor",
|
||||||
|
"intermediate_size",
|
||||||
|
"max_position_embeddings",
|
||||||
|
"mlp_bias",
|
||||||
|
"mlp_dropout",
|
||||||
|
"classifier_activation",
|
||||||
|
"global_attn_every_n_layers",
|
||||||
|
],
|
||||||
# position_embedding_type not used and deprecated. Should be deleted in v4.55
|
# position_embedding_type not used and deprecated. Should be deleted in v4.55
|
||||||
"LayoutLMConfig": ["position_embedding_type"],
|
"LayoutLMConfig": ["position_embedding_type"],
|
||||||
"MarkupLMConfig": ["position_embedding_type"],
|
"MarkupLMConfig": ["position_embedding_type"],
|
||||||
|
Loading…
Reference in New Issue
Block a user