Mirror of https://github.com/huggingface/transformers.git

* initial cut of modernbert for transformers
* small bug fixes
* fixes
* Update import
* Use compiled mlp->mlp_norm to match research implementation
* Propagate changes in modular to modeling
* Replace duplicate attn_out_dropout in favor of attention_dropout
cc @warner-benjamin let me know if the two should remain separate!
* Update BOS to CLS and EOS to SEP
Please confirm @warner-benjamin
* Set default classifier bias to False, matching research repo
* Update tie_word_embeddings description
* Fix _init_weights for ForMaskedLM
* Match base_model_prefix
* Add compiled_head to match research repo outputs
* Fix imports for ModernBertForMaskedLM
* Just use "gelu" default outright for classifier
* Fix config name typo: initalizer -> initializer
* Remove some unused parameters in docstring. Still lots to edit there!
* Compile the embeddings forward
Not having this resulted in very slight differences - so small it wasn't even noticed for the base model, only for the large model.
But the tiny difference for the large model propagated from the embedding layer through the rest of the model, leading to notable differences of ~0.0084 average per value, up to 0.2343 in the worst case.
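For reference, compiling the embedding path looks roughly like the sketch below (module and attribute names are illustrative, not necessarily the exact ModernBERT layout):

```python
import torch
import torch.nn as nn


class CompiledEmbeddings(nn.Module):
    # Illustrative layout: token embedding -> norm -> dropout, with the whole
    # path compiled so its numerics match the compiled research implementation.
    def __init__(self, vocab_size: int, hidden_size: int, dropout: float = 0.0):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)

    @torch.compile(dynamic=True)
    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.compiled_embeddings(input_ids)
```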
* Add drafts for ForSequenceClassification/ForTokenClassification
* Add initial SDPA support (not exactly equivalent to FA2 yet!)
During testing, FA2 and SDPA still differ by about 0.0098 per value in the token embeddings. It still predicts the correct mask fills, but I'd like to get it fully 1-1 if possible.
* Only use attention dropout if training
* Add initial eager attention support (also not equivalent to FA2 yet!)
Frustratingly, I also can't get eager to be equivalent to FA2 (or sdpa), but it does get really close, i.e. avg ~0.010 difference per value.
Especially if I use fp32 for both FA2 & eager: avg ~0.0029 difference per value.
The fill-mask results are good with eager.
* Add initial tests, output_attentions, output_hidden_states, prune_heads
Tests are based on BERT, not all tests pass yet: 23 failed, 79 passed, 100 skipped
* Remove kwargs from ModernBertForMaskedLM
Disable sparse_prediction by default to match the normal HF, can be enabled via config
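If desired, it can be switched back on through the config, e.g. as in this sketch (the checkpoint name below is only a placeholder):

```python
from transformers import AutoConfig, AutoModelForMaskedLM

# sparse_prediction=True computes MLM logits only for masked positions,
# which saves memory/compute during training; it is disabled by default.
config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base", sparse_prediction=True)
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base", config=config)
```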
* Remove/adjust/skip improper tests; warn if padding but no attn mask
* Run formatting etc.
* Run python utils/custom_init_isort.py
* FlexAttention with unpadded sequences (matches FA2 within bf16 numerics)
* Reformat init_weights based on review
* self -> module in attention forwards
* Remove if config.tie_word_embeddings
* Reformat output projection on a different line
* Remove pruning
* Remove assert
* Call contiguous() to simplify paths
* Remove prune_qkv_linear_layer
* Format code
* Keep as kwargs, only use if needed
* Remove unused codepaths & related config options
* Remove 3d attn_mask test; fix token classification tuple output
* Reorder: attention_mask above position_ids, fixes gradient checkpointing
* Fix usage when FA2 or torch v2.5+ is not available
* Make torch.compile/triton optional
Should we rename 'compile'? It's a bit vague
* Separate pooling options into separate functions (cls, mean) - cls as default
* Simplify _pad_modernbert_output, remove unused labels path
* Update tied weights to remove decoder.weight, simplify decoder loading
* Adaptively set config.compile based on hf_device_map/device/resize, etc.
* Update ModernBertConfig docstring
* Satisfy some consistency checks, add unfinished docs
* Only set compile to False if there's more than 1 device
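Roughly, the adaptive default boils down to logic like this sketch (the function name and the exact set of checks are assumptions, not the actual implementation):

```python
def resolve_reference_compile(config, hf_device_map=None) -> bool:
    # Sketch only: honour an explicit user setting, otherwise disable the
    # compiled code paths when the model is split over more than one device,
    # since torch.compile tends to break in that situation.
    if config.reference_compile is not None:
        return config.reference_compile
    if hf_device_map is not None and len(set(hf_device_map.values())) > 1:
        return False
    # The real checks also cover other cases (e.g. resized embeddings).
    return True
```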
* Add docstrings for public ModernBert classes
* Don't replace docstring returns - ends up being duplicated
* Fix mistake in toctree
* Reformat toctree
* Patched FlexAttention, SDPA, Eager with Local Attention
* Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial
both to match the original performance, and to get the highest inference speed without requiring users to manually pick FA2
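Conceptually, the default selection is just a graceful fallback like the sketch below (the real logic lives in the model's `_autoset_attn_implementation` and reuses the library's existing availability checks):

```python
def pick_attn_implementation(flash_attn_2_available: bool, sdpa_available: bool) -> str:
    # Prefer the fastest backend the environment can actually run,
    # so users get FA2 speed without having to request it explicitly.
    if flash_attn_2_available:
        return "flash_attention_2"
    if sdpa_available:
        return "sdpa"
    return "eager"
```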
* Patch test edge case with Idefics3 not working with 'attn_implementation="sdpa"'
* Repad all_hidden_states as well
* rename config.compile to reference_compile
* disable flex_attention since it crashes
* Update modernbert.md
* Using dtype min to mask in eager
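In the eager path, masked positions are filled with the dtype's minimum value before the softmax rather than `-inf`, roughly like this sketch (the tensor layout is an assumption):

```python
import torch


def eager_attention_probs(query, key, attention_mask=None):
    # query/key: [batch, heads, seq_len, head_dim]
    # attention_mask: [batch, 1, seq_len, seq_len], 1 for allowed positions, 0 for masked ones.
    scores = torch.matmul(query, key.transpose(-1, -2)) * query.size(-1) ** -0.5
    if attention_mask is not None:
        # torch.finfo(dtype).min keeps the masking numerically safe in fp16/bf16.
        scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1)
```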
* Fully remove flex attention for now
It's only compatible with the nightly torch 2.6, so we'll leave it be for now. It's also slower than eager/sdpa.
Also, update compile -> reference_compile in one more case
* Call contiguous to allow for .view()
* Copyright 2020 -> 2024
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update/simplify __init__ structure
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Remove "... if dropout_prob > 0 else identity"
As dropout with p=0.0 should be as efficient as an identity op
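A quick check that `nn.Dropout(0.0)` is already a no-op even in training mode, so the `nn.Identity` branch adds nothing:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.0)
drop.train()  # dropout is only active in training mode anyway
x = torch.randn(2, 4)
assert torch.equal(drop(x), x)  # with p=0.0 the input comes back unchanged
```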
* re-use existing pad/unpad functions instead of creating new ones
* remove flexattention method
* Compute attention_mask and local_attention_mask once in modeling
* Simplify sequence classification prediction heads, only CLS now
Users can make custom heads if they feel like it
Also removes the unnecessary pool parameter
* Simplify module.training in eager attn
* Also export ModernBertPreTrainedModel
* Update the documentation with links to finetuning scripts
* Explain local_attention_mask parameter in docstring
* Simplify _autoset_attn_implementation, rely on super()
* Keep "in" to initialize Prediction head
Double-checked with Benjamin that it's correct/what we used for pretraining
* add back mean pooling
* Use the pooling head in TokenClassification
* update copyright
* Reset config._attn_implementation_internal on failure
* Allow optional attention_mask in ForMaskedLM head
* fix failing run_slow tests
* Add links to the paper
* Remove unpad_no_grad, always pad/unpad without gradients
* local_attention_mask -> sliding_window_mask
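For the local-attention layers, the sliding window mask restricts each token to a fixed neighbourhood; a minimal sketch of building one (function name and exact shapes are assumptions):

```python
import torch


def build_sliding_window_mask(attention_mask: torch.Tensor, window_size: int) -> torch.Tensor:
    # attention_mask: [batch, seq] with 1 for real tokens, 0 for padding.
    # Returns a boolean [batch, 1, seq, seq] mask that is True where attention is allowed.
    seq_len = attention_mask.size(-1)
    positions = torch.arange(seq_len, device=attention_mask.device)
    # |i - j| <= window_size // 2 keeps only nearby tokens.
    local = (positions[None, :] - positions[:, None]).abs() <= window_size // 2
    padding = attention_mask.bool()[:, None, None, :]
    return local[None, None, :, :] & padding
```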
* Revert "Use the pooling head in TokenClassification"
This reverts commit 99c38badd1.
There was no real motivation, no info on whether having this bigger head does anything useful.
* Simplify pooling, 2 options via if-else
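The resulting pooling in the classification heads reduces to an if-else like this sketch (CLS token by default, mean over non-padding tokens otherwise):

```python
import torch


def pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor, pooling: str = "cls") -> torch.Tensor:
    # hidden_states: [batch, seq, hidden]; attention_mask: [batch, seq] with 1 for real tokens.
    if pooling == "cls":
        return hidden_states[:, 0]
    # mean pooling over non-padding tokens
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
```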
---------
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Said Taghadouini <taghadouinisaid@gmail.com>
Co-authored-by: Benjamin Clavié <ben@clavie.eu>
Co-authored-by: Antoine Chaffin <ant54600@hotmail.fr>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest

import pytest

from transformers import ModernBertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
    CaptureLogger,
    require_flash_attn,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        ModernBertForMaskedLM,
        ModernBertForSequenceClassification,
        ModernBertForTokenClassification,
        ModernBertModel,
        logging,
    )

class ModernBertModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        pad_token_id=0,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_activation="gelu",
        mlp_dropout=0.0,
        attention_dropout=0.0,
        embedding_dropout=0.0,
        classifier_dropout=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.mlp_dropout = mlp_dropout
        self.attention_dropout = attention_dropout
        self.embedding_dropout = embedding_dropout
        self.classifier_dropout = classifier_dropout
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        config = ModernBertConfig(
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_activation=self.hidden_activation,
            mlp_dropout=self.mlp_dropout,
            attention_dropout=self.attention_dropout,
            embedding_dropout=self.embedding_dropout,
            classifier_dropout=self.classifier_dropout,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )
        if test := os.environ.get("PYTEST_CURRENT_TEST", False):
            test_name = test.split(":")[-1].split(" ")[0]

            # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
            # that compilation doesn't work. Users can then set compile=False when loading the model,
            # much like here. We're testing whether it works once they've done that.
            if test_name == "test_retain_grad_hidden_states_attentions":
                config.reference_compile = False
            # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
            # as the others don't support outputted attentions
            if test_name in (
                "test_attention_outputs",
                "test_hidden_states_output",
                "test_retain_grad_hidden_states_attentions",
            ):
                config._attn_implementation = "eager"
        return config

    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        model = ModernBertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = ModernBertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = ModernBertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = ModernBertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

@require_torch
class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_torchscript = False

    all_model_classes = (
        (
            ModernBertModel,
            ModernBertForMaskedLM,
            ModernBertForSequenceClassification,
            ModernBertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = ()
    pipeline_model_mapping = (
        {
            "feature-extraction": ModernBertModel,
            "fill-mask": ModernBertForMaskedLM,
            "text-classification": ModernBertForSequenceClassification,
            "token-classification": ModernBertForTokenClassification,
            "zero-shot": ModernBertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    model_split_percents = [0.5, 0.8, 0.9]

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if inputs_dict.get("output_attentions", False):
            inputs_dict["output_attentions"] = True

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = ModernBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                # The classifier.weight from ModernBertForSequenceClassification and ModernBertForTokenClassification
                # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
                if param.requires_grad and not (
                    name == "classifier.weight"
                    and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
                ):
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
    def test_inputs_embeds(self):
        pass

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_for_warning_if_padding_and_no_attention_mask(self):
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.model_tester.prepare_config_and_inputs()

        # Set pad tokens in the input_ids
        input_ids[0, 0] = config.pad_token_id

        # Check for warnings if the attention_mask is missing.
        logger = logging.get_logger("transformers.modeling_utils")
        # clear cache so we can test the warning is emitted (from `warning_once`).
        logger.warning_once.cache_clear()

        with CaptureLogger(logger) as cl:
            model = ModernBertModel(config=config)
            model.to(torch_device)
            model.eval()
            model(input_ids, attention_mask=None)
        self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

    @unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.")
    def test_sdpa_can_dispatch_non_composite_models(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        # NOTE: placeholder checkpoint; public ModernBERT weights are not yet
        # available (see the integration test class below).
        model_name = "google-bert/bert-base-uncased"
        model = ModernBertModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest(reason="ModernBert flash attention does not support right padding")

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_conversion(self):
        self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")

@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):
    """
    These still need to be written, once public models are available.
    """
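Once public checkpoints exist, an integration test could start out along the lines of the sketch below (as a method of `ModernBertModelIntegrationTest`; the checkpoint name is a placeholder and only shapes are asserted, since no reference outputs are available yet):

```python
    @slow
    def test_masked_lm_output_shape(self):
        # Placeholder checkpoint name; replace with the released ModernBERT weights.
        model = ModernBertForMaskedLM.from_pretrained("answerdotai/ModernBERT-base").to(torch_device)
        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]], device=torch_device)
        with torch.no_grad():
            logits = model(input_ids).logits
        self.assertEqual(logits.shape, (1, input_ids.shape[1], model.config.vocab_size))
```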