Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
ModernBERT bug fixes (#35404)
* bug fixes
* organize imports
* wrap cpu warning in reference_compile
* Avoid needing repad_logits_with_grad, always repad with grads when training
I'm not 100% sure that the conditional with "or labels is None" makes sense, though - I'm not sure what the intention is there. Perhaps we can remove it?
* Revert "Avoid needing repad_logits_with_grad, always repad with grads when training"
This reverts commit cedcb4e89b.
* Fix grammar: keep -> keeps
* Propagate grammar fix with modular_model_converter
---------
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
parent e97d7a5be5
commit 1e3ddcb2d0
@@ -505,7 +505,7 @@
       - local: model_doc/mobilebert
         title: MobileBERT
       - local: model_doc/modernbert
-        title: ModernBert
+        title: ModernBERT
       - local: model_doc/mpnet
         title: MPNet
       - local: model_doc/mpt
@@ -14,7 +14,7 @@ rendered properly in your Markdown viewer.
 
 -->
 
-# ModernBert
+# ModernBERT
 
 <div class="flex flex-wrap space-x-1">
   <a href="https://huggingface.co/models?filter=modernbert">
@@ -27,7 +27,7 @@ rendered properly in your Markdown viewer.
 
 ## Overview
 
-The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli.
+The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard and Iacopo Poli.
 
 It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
 
@@ -109,6 +109,9 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
+        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
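As a quick, hedged illustration (not part of this diff; assumes a `transformers` version that ships ModernBERT), the new option is set like any other config field:

```python
from transformers import ModernBertConfig

# Sketch: ask ModernBertForMaskedLM to keep gradients on the repadded logits.
# The default (False) detaches them whenever labels are passed under Flash Attention 2.
config = ModernBertConfig(repad_logits_with_grad=True)
print(config.repad_logits_with_grad)  # True
```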
@@ -164,6 +167,7 @@ class ModernBertConfig(PretrainedConfig):
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
+        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -203,6 +207,7 @@ class ModernBertConfig(PretrainedConfig):
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
+        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -20,6 +20,7 @@
 # limitations under the License.
 
 import math
+from contextlib import nullcontext
 from typing import Dict, Optional, Tuple, Union
 
 import torch
@@ -632,12 +633,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     ):
         # If the user didn't specify anything, try to use flash_attention_2 if available.
         # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
         if config._attn_implementation_internal is None:
             config._attn_implementation_internal = "flash_attention_2"
             try:
                 return cls._check_and_enable_flash_attn_2(
                     config,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=torch.float16,
                     device_map=device_map,
                     hard_check_only=False,
                     check_device_map=check_device_map,
@@ -647,7 +650,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
         return super()._autoset_attn_implementation(
             config,
             use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch_dtype,
+            torch_dtype=torch.float16,
             device_map=device_map,
             check_device_map=check_device_map,
         )
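Passing `torch.float16` here only affects the FA2 availability check, so the dtype warning is not emitted for fp32 models; the model itself keeps whatever `torch_dtype` the user requested. A minimal sketch of the resulting behavior (the released `answerdotai/ModernBERT-base` checkpoint and an installed `flash-attn` are assumed):

```python
import torch
from transformers import AutoModelForMaskedLM

# ModernBERT auto-selects flash_attention_2 when available, even in float32,
# without the usual "Flash Attention 2 requires fp16/bf16" warning.
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=torch.float32)
print(model.config._attn_implementation)  # "flash_attention_2" if flash-attn is usable, otherwise the SDPA/eager fallback
```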
@@ -672,6 +675,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 )
             self.config.reference_compile = False
 
+        if self.device.type == "cpu":
+            if self.config.reference_compile:
+                logger.warning_once(
+                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+                    "Falling back to non-compiled mode."
+                )
+            self.config.reference_compile = False
+
         if self.config.reference_compile is None:
             self.config.reference_compile = is_triton_available()
 
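With this change the CPU warning only fires when the user explicitly asked for compilation; if `reference_compile` was left unset, it is silently forced to `False` on CPU. A hedged sketch (checkpoint name assumed; the check is expected to run on the first forward pass):

```python
from transformers import AutoModelForMaskedLM

# Explicitly requesting compilation on a CPU-only machine: the model warns once
# and falls back to the non-compiled path instead of attempting torch.compile on CPU.
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base", reference_compile=True)
outputs = model(**model.dummy_inputs)   # the fallback above is applied when the model checks its compile settings
print(model.config.reference_compile)   # False after the CPU fallback
```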
@@ -763,8 +774,8 @@ def _pad_modernbert_output(
 MODERNBERT_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
+            Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
+            by default should you provide it.
 
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
@@ -790,7 +801,7 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
             perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers.
+            far-away tokens in the local attention layers when not using Flash Attention.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
@@ -805,11 +816,11 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
         cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
             Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
         max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch. Used to pad the output tensors.
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
         batch_size (`int`, *optional*):
             Batch size of the input sequences. Used to pad the output tensors.
         seq_len (`int`, *optional*):
-            Sequence length of the input sequences. Used to pad the output tensors.
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
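For readers unfamiliar with the unpadded layout these arguments describe, a small illustrative sketch (not from this diff) of how they relate to an ordinary attention mask:

```python
import torch

# Two sequences of true lengths 3 and 2, padded to a common length of 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1)                                        # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0), (1, 0))  # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())             # 3 -> longest sequence, padding excluded
batch_size, seq_len = attention_mask.shape  # 2, 4 -> seq_len includes padding
```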
@@ -1128,8 +1139,9 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with torch.no_grad():
+            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
             output = (logits,)
             return ((loss,) + output) if loss is not None else output
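To make the new context manager concrete, a hedged usage sketch (assumes `flash-attn`, a CUDA device, and the released `answerdotai/ModernBERT-base` checkpoint; without Flash Attention 2 the flag has no effect because this branch is never taken):

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base",
    torch_dtype=torch.bfloat16,
    repad_logits_with_grad=True,   # keep gradients on the repadded logits
).to("cuda")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])
# With the flag set, the repadded logits stay differentiable even though labels were passed;
# with the default False they would be detached and only outputs.loss would carry gradients.
print(outputs.logits.requires_grad)
```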
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import math
+from contextlib import nullcontext
 from typing import Dict, Literal, Optional, Tuple, Union
 
 import torch
@@ -141,6 +142,9 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
+        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -196,6 +200,7 @@ class ModernBertConfig(PretrainedConfig):
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
+        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -235,6 +240,7 @@ class ModernBertConfig(PretrainedConfig):
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
+        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -857,12 +863,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     ):
         # If the user didn't specify anything, try to use flash_attention_2 if available.
         # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
         if config._attn_implementation_internal is None:
             config._attn_implementation_internal = "flash_attention_2"
             try:
                 return cls._check_and_enable_flash_attn_2(
                     config,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=torch.float16,
                     device_map=device_map,
                     hard_check_only=False,
                     check_device_map=check_device_map,
@@ -872,7 +880,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
         return super()._autoset_attn_implementation(
             config,
             use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch_dtype,
+            torch_dtype=torch.float16,
             device_map=device_map,
             check_device_map=check_device_map,
         )
@@ -897,6 +905,14 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 )
             self.config.reference_compile = False
 
+        if self.device.type == "cpu":
+            if self.config.reference_compile:
+                logger.warning_once(
+                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+                    "Falling back to non-compiled mode."
+                )
+            self.config.reference_compile = False
+
         if self.config.reference_compile is None:
             self.config.reference_compile = is_triton_available()
 
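Conversely (a sketch under the same assumptions as above), compilation can also be opted out of explicitly; an explicit `False` is never overridden by the `is_triton_available()` auto-detection, which only runs while the value is still `None`:

```python
from transformers import AutoModelForMaskedLM

# Sketch: disable the compiled prediction head up front. Because the value is not None,
# the triton auto-detection above leaves it untouched, whether or not triton is installed.
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base", reference_compile=False)
print(model.config.reference_compile)  # False
```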
@@ -916,8 +932,8 @@ class ModernBertPreTrainedModel(PreTrainedModel):
 MODERNBERT_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
+            Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
+            by default should you provide it.
 
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
@@ -943,7 +959,7 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
             perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers.
+            far-away tokens in the local attention layers when not using Flash Attention.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
@@ -958,11 +974,11 @@ MODERNBERT_INPUTS_DOCSTRING = r"""
         cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
             Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
         max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch. Used to pad the output tensors.
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
         batch_size (`int`, *optional*):
             Batch size of the input sequences. Used to pad the output tensors.
         seq_len (`int`, *optional*):
-            Sequence length of the input sequences. Used to pad the output tensors.
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
@@ -1281,8 +1297,9 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with torch.no_grad():
+            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
             output = (logits,)
             return ((loss,) + output) if loss is not None else output