Arthur 2025-07-03 16:00:28 +02:00
parent cd63172ced
commit cf2e98c9ff
27 changed files with 77 additions and 54 deletions

View File

@@ -1925,6 +1925,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
+ - **can_record_outputs** (dict): Maps output names (e.g., "attentions", "hidden_states")
+ to either:
+ - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+ * index=0 for "hidden_states"
+ * index=1 for "attentions"
+ - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+ Examples:
+ These two are equivalent:
+ _can_record_outputs = {
+ "attentions": LlamaAttention,
+ "hidden_states": LlamaDecoderLayer
+ }
+ _can_record_outputs = {
+ "attentions": OutputRecorder(LlamaAttention, index=1),
+ "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+ }
"""
config_class = None
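
To make the new convention concrete, here is a minimal, hypothetical sketch of how a `_can_record_outputs`-style mapping could be consumed with plain forward hooks. The names `SimpleOutputRecorder`, `DEFAULT_INDEX`, and `collect_outputs` are illustrative only; the library's actual recording machinery (wired through `check_model_inputs` in `transformers.utils.generic`) differs in detail.

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

from torch import nn


@dataclass
class SimpleOutputRecorder:
    # Toy stand-in for OutputRecorder: which module class to watch, which element
    # of that module's output tuple to keep, and an optional name filter standing
    # in for `layer_name` (here treated as a substring of the submodule path).
    target_class: type
    index: int = 0
    layer_name: Optional[str] = None


# Default index conventions described in the docstring above.
DEFAULT_INDEX = {"hidden_states": 0, "attentions": 1}


def collect_outputs(model: nn.Module, can_record_outputs: dict, *args, **kwargs):
    # Run the model once and gather the requested intermediate outputs via hooks.
    collected = defaultdict(list)
    handles = []
    for name, spec in can_record_outputs.items():
        if not isinstance(spec, SimpleOutputRecorder):
            # Bare module class: wrap it with the default index for this output name.
            spec = SimpleOutputRecorder(spec, index=DEFAULT_INDEX.get(name, 0))
        for module_path, module in model.named_modules():
            if not isinstance(module, spec.target_class):
                continue
            if spec.layer_name is not None and spec.layer_name not in module_path:
                continue

            def hook(_mod, _inputs, output, key=name, idx=spec.index):
                collected[key].append(output[idx] if isinstance(output, tuple) else output)

            handles.append(module.register_forward_hook(hook))
    try:
        result = model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()
    return result, {key: tuple(values) for key, values in collected.items()}

With this helper, the two equivalent dicts from the docstring yield the same recorded tensors, because the bare-class form is rewritten into recorders carrying the default indices.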

View File

@@ -321,8 +321,8 @@ class ArceePreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (ArceeDecoderLayer, 0),
- "attentions": (ArceeAttention, 1),
+ "hidden_states": ArceeDecoderLayer,
+ "attentions": ArceeAttention,
}
def _init_weights(self, module):

View File

@@ -665,8 +665,8 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (AriaTextDecoderLayer, 0),
- "attentions": (AriaTextAttention, 1),
+ "hidden_states": AriaTextDecoderLayer,
+ "attentions": AriaTextAttention,
}
def _init_weights(self, module):

View File

@@ -316,8 +316,8 @@ class BitNetPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (BitNetDecoderLayer, 0),
- "attentions": (BitNetAttention, 1),
+ "hidden_states": BitNetDecoderLayer,
+ "attentions": BitNetAttention,
}
def _init_weights(self, module):

View File

@@ -349,8 +349,8 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (CohereDecoderLayer, 0),
- "attentions": (CohereAttention, 1),
+ "hidden_states": CohereDecoderLayer,
+ "attentions": CohereAttention,
}
def _init_weights(self, module):

View File

@@ -326,8 +326,8 @@ class Cohere2PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Cohere2DecoderLayer, 0),
- "attentions": (Cohere2Attention, 1),
+ "hidden_states": Cohere2DecoderLayer,
+ "attentions": Cohere2Attention,
}
def _init_weights(self, module):

View File

@@ -502,8 +502,8 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (DeepseekV3DecoderLayer, 0),
- "attentions": (DeepseekV3Attention, 1),
+ "hidden_states": DeepseekV3DecoderLayer,
+ "attentions": DeepseekV3Attention,
}
def _init_weights(self, module):

View File

@@ -553,8 +553,8 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = False
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (DiffLlamaDecoderLayer, 0),
- "attentions": (DiffLlamaAttention, 1),
+ "hidden_states": DiffLlamaDecoderLayer,
+ "attentions": DiffLlamaAttention,
}
def _init_weights(self, module):

View File

@@ -422,8 +422,8 @@ class Dots1PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Dots1DecoderLayer, 0),
- "attentions": (Dots1Attention, 1),
+ "hidden_states": Dots1DecoderLayer,
+ "attentions": Dots1Attention,
}
def _init_weights(self, module):

View File

@@ -318,8 +318,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (GemmaDecoderLayer, 0),
- "attentions": (GemmaAttention, 1),
+ "hidden_states": GemmaDecoderLayer,
+ "attentions": GemmaAttention,
}
def _init_weights(self, module):

View File

@@ -348,8 +348,8 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Gemma2DecoderLayer, 0),
- "attentions": (Gemma2Attention, 1),
+ "hidden_states": Gemma2DecoderLayer,
+ "attentions": Gemma2Attention,
}
def _init_weights(self, module):

View File

@@ -438,8 +438,8 @@ class Gemma3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Gemma3DecoderLayer, 0),
- "attentions": (Gemma3Attention, 1),
+ "hidden_states": Gemma3DecoderLayer,
+ "attentions": Gemma3Attention,
}
def _init_weights(self, module):

View File

@@ -1494,8 +1494,8 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Gemma3nTextDecoderLayer, 0),
- "attentions": (Gemma3nTextAttention, 1),
+ "hidden_states": Gemma3nTextDecoderLayer,
+ "attentions": Gemma3nTextAttention,
}
def _init_weights(self, module):

View File

@@ -335,8 +335,8 @@ class GlmPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (GlmDecoderLayer, 0),
- "attentions": (GlmAttention, 1),
+ "hidden_states": GlmDecoderLayer,
+ "attentions": GlmAttention,
}
def _init_weights(self, module):

View File

@@ -339,8 +339,8 @@ class Glm4PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Glm4DecoderLayer, 0),
- "attentions": (Glm4Attention, 1),
+ "hidden_states": Glm4DecoderLayer,
+ "attentions": Glm4Attention,
}
def _init_weights(self, module):

View File

@@ -368,8 +368,8 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (GPTNeoXDecoderLayer, 0),
- "attentions": (GPTNeoXAttention, 1),
+ "hidden_states": GPTNeoXDecoderLayer,
+ "attentions": GPTNeoXAttention,
}
_keys_to_ignore_on_load_unexpected = [r"attention.bias", r"attention.masked_bias"]

View File

@@ -313,8 +313,8 @@ class GranitePreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (GraniteDecoderLayer, 0),
- "attentions": (GraniteAttention, 1),
+ "hidden_states": GraniteDecoderLayer,
+ "attentions": GraniteAttention,
}
def _init_weights(self, module):

View File

@@ -320,8 +320,8 @@ class HeliumPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (HeliumDecoderLayer, 0),
- "attentions": (HeliumAttention, 1),
+ "hidden_states": HeliumDecoderLayer,
+ "attentions": HeliumAttention,
}
def _init_weights(self, module):

View File

@@ -29,7 +29,7 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_int
+ from ...utils import TransformersKwargs, auto_docstring, torch_int
from .configuration_mlcd import MLCDVisionConfig
@@ -370,7 +370,6 @@ class MLCDEncoder(nn.Module):
self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- @can_return_tuple
def forward(
self,
inputs_embeds: torch.FloatTensor,

View File

@@ -298,8 +298,8 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (OlmoDecoderLayer, 0),
- "attentions": (OlmoAttention, 1),
+ "hidden_states": OlmoDecoderLayer,
+ "attentions": OlmoAttention,
}
def _init_weights(self, module):

View File

@@ -303,8 +303,8 @@ class Olmo2PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Olmo2DecoderLayer, 0),
- "attentions": (Olmo2Attention, 1),
+ "hidden_states": Olmo2DecoderLayer,
+ "attentions": Olmo2Attention,
}
def _init_weights(self, module):

View File

@@ -303,8 +303,8 @@ class PhiPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (PhiDecoderLayer, 0),
- "attentions": (PhiAttention, 1),
+ "hidden_states": PhiDecoderLayer,
+ "attentions": PhiAttention,
}
def _init_weights(self, module):

View File

@@ -267,8 +267,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Qwen2DecoderLayer, 0),
- "attentions": (Qwen2Attention, 1),
+ "hidden_states": Qwen2DecoderLayer,
+ "attentions": Qwen2Attention,
}
def _init_weights(self, module):

View File

@@ -293,8 +293,8 @@ class Qwen3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (Qwen3DecoderLayer, 0),
- "attentions": (Qwen3Attention, 1),
+ "hidden_states": Qwen3DecoderLayer,
+ "attentions": Qwen3Attention,
}
def _init_weights(self, module):

View File

@@ -30,13 +30,13 @@ from torch import Tensor, nn
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
- from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs
+ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...processing_utils import Unpack
- from ...utils import auto_docstring, logging
+ from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
@@ -474,7 +474,10 @@ class SamHQVisionNeck(nn.Module):
class SamHQVisionEncoder(PreTrainedModel):
- _can_record_outputs = {"hidden_states": (SamHQVisionLayer, 0), "vision_attentions": (SamHQVisionAttention, 1)}
+ _can_record_outputs = {
+ "hidden_states": OutputRecorder(SamHQVisionLayer),
+ "vision_attentions": OutputRecorder(SamHQVisionAttention, index=1),
+ }
def __init__(self, config: SamHQVisionConfig):
super().__init__(config)
@@ -837,7 +840,9 @@ class SamHQFeedForward(nn.Module):
class SamHQMaskDecoder(PreTrainedModel):
- _can_record_outputs = {"mask_decoder_attentions": (SamHQVisionAttention, 1, "transformer")}
+ _can_record_outputs = {
+ "mask_decoder_attentions": OutputRecorder(SamHQVisionAttention, index=1, layer_name="transformer")
+ }
def __init__(self, config: SamHQMaskDecoderConfig):
super().__init__(config)
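
As in the docstring example above, `OutputRecorder(SamHQVisionLayer)` with no explicit index replaces the old `(SamHQVisionLayer, 0)` tuple, while the attention entries keep `index=1` to pick the attention weights out of each module's output. The `layer_name="transformer"` argument presumably restricts recording to `SamHQVisionAttention` instances that live under the mask decoder's `transformer` submodule. A tiny, self-contained sketch of that kind of name-based scoping (the toy classes below are illustrative stand-ins, not the SAM-HQ modules):

from torch import nn


class ToyAttention(nn.Module):
    # Stand-in attention block; only its position in the module tree matters here.
    def forward(self, x):
        return x


class ToyMaskDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.ModuleDict({"self_attn": ToyAttention()})  # scoped target
        self.iou_attn = ToyAttention()  # same class, different place in the tree


decoder = ToyMaskDecoder()
# Emulate layer_name="transformer": keep only matching modules under that submodule.
selected = [
    path
    for path, module in decoder.named_modules()
    if isinstance(module, ToyAttention) and "transformer" in path
]
print(selected)  # ['transformer.self_attn'] -- 'iou_attn' is ignored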

View File

@@ -297,8 +297,8 @@ class SmolLM3PreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (SmolLM3DecoderLayer, 0),
- "attentions": (SmolLM3Attention, 1),
+ "hidden_states": SmolLM3DecoderLayer,
+ "attentions": SmolLM3Attention,
}
def _init_weights(self, module):

View File

@@ -590,8 +590,8 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
- "hidden_states": (T5GemmaDecoderLayer, 0),
- "attentions": (T5GemmaAttention, 1),
+ "hidden_states": T5GemmaDecoderLayer,
+ "attentions": T5GemmaAttention,
}
def _init_weights(self, module):