diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index cac1ee6ae4a..c334fb97d5e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1925,6 +1925,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
         - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP models, `pixel_values` for vision models and `input_values` for speech models).
+        - **can_record_outputs** (dict): Maps output names (e.g., "attentions", "hidden_states")
+          to either:
+            - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+                * index=0 for "hidden_states"
+                * index=1 for "attentions"
+            - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+
+          Examples:
+          These two are equivalent:
+
+          _can_record_outputs = {
+              "attentions": LlamaAttention,
+              "hidden_states": LlamaDecoderLayer
+          }
+
+          _can_record_outputs = {
+              "attentions": OutputRecorder(LlamaAttention, index=1),
+              "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+          }
     """

     config_class = None
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 9bd34ec30bb..e77691ebf3d 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -321,8 +321,8 @@ class ArceePreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (ArceeDecoderLayer, 0),
-        "attentions": (ArceeAttention, 1),
+        "hidden_states": ArceeDecoderLayer,
+        "attentions": ArceeAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 028129e8946..b5a66b35a06 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -665,8 +665,8 @@ class AriaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # MoE models don't work with torch.compile (dynamic slicing)
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (AriaTextDecoderLayer, 0),
-        "attentions": (AriaTextAttention, 1),
+        "hidden_states": AriaTextDecoderLayer,
+        "attentions": AriaTextAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index 199a5daa321..88ea7fcc486 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -316,8 +316,8 @@ class BitNetPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (BitNetDecoderLayer, 0),
-        "attentions": (BitNetAttention, 1),
+        "hidden_states": BitNetDecoderLayer,
+        "attentions": BitNetAttention,
     }

     def _init_weights(self, module):
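Reviewer note, not part of the patch: the bare-class shorthand only changes how `_can_record_outputs` is declared; the caller-facing behavior is unchanged. For context, a minimal usage sketch of what these declarations feed, assuming a Qwen2 checkpoint is available from the Hub:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    # output_hidden_states / output_attentions are fulfilled by recording
    # from the module classes declared in _can_record_outputs.
    out = model(**inputs, output_hidden_states=True, output_attentions=True)

print(len(out.hidden_states))  # num_hidden_layers + 1 (includes the embedding output)
print(len(out.attentions))     # one entry per decoder layer
```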
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 1c021ed7a5c..22d5a5d40e5 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -349,8 +349,8 @@ class CoherePreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (CohereDecoderLayer, 0),
-        "attentions": (CohereAttention, 1),
+        "hidden_states": CohereDecoderLayer,
+        "attentions": CohereAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index ec784480802..01bcd161dfa 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -326,8 +326,8 @@ class Cohere2PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Cohere2DecoderLayer, 0),
-        "attentions": (Cohere2Attention, 1),
+        "hidden_states": Cohere2DecoderLayer,
+        "attentions": Cohere2Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index 20e3efbb985..77d683e4f82 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -502,8 +502,8 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (DeepseekV3DecoderLayer, 0),
-        "attentions": (DeepseekV3Attention, 1),
+        "hidden_states": DeepseekV3DecoderLayer,
+        "attentions": DeepseekV3Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 524bb16ee52..6e1cfea95ba 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -553,8 +553,8 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = False
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (DiffLlamaDecoderLayer, 0),
-        "attentions": (DiffLlamaAttention, 1),
+        "hidden_states": DiffLlamaDecoderLayer,
+        "attentions": DiffLlamaAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index 2fbae710435..654c13b349a 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -422,8 +422,8 @@ class Dots1PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Dots1DecoderLayer, 0),
-        "attentions": (Dots1Attention, 1),
+        "hidden_states": Dots1DecoderLayer,
+        "attentions": Dots1Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 88719cdbddb..8ca1a9de100 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -318,8 +318,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (GemmaDecoderLayer, 0),
-        "attentions": (GemmaAttention, 1),
+        "hidden_states": GemmaDecoderLayer,
"attentions": GemmaAttention, } def _init_weights(self, module): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fdd6071d86b..f85aa3b4e18 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -348,8 +348,8 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (Gemma2DecoderLayer, 0), - "attentions": (Gemma2Attention, 1), + "hidden_states": Gemma2DecoderLayer, + "attentions": Gemma2Attention, } def _init_weights(self, module): diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 6a3918958ab..1dcf6d1e40a 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -438,8 +438,8 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (Gemma3DecoderLayer, 0), - "attentions": (Gemma3Attention, 1), + "hidden_states": Gemma3DecoderLayer, + "attentions": Gemma3Attention, } def _init_weights(self, module): diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3db4f2a970e..e0a577868e0 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1494,8 +1494,8 @@ class Gemma3nPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (Gemma3nTextDecoderLayer, 0), - "attentions": (Gemma3nTextAttention, 1), + "hidden_states": Gemma3nTextDecoderLayer, + "attentions": Gemma3nTextAttention, } def _init_weights(self, module): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 94b33a06135..9512a057464 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -335,8 +335,8 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (GlmDecoderLayer, 0), - "attentions": (GlmAttention, 1), + "hidden_states": GlmDecoderLayer, + "attentions": GlmAttention, } def _init_weights(self, module): diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 8aeb91db42f..9d543446d70 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -339,8 +339,8 @@ class Glm4PreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (Glm4DecoderLayer, 0), - "attentions": (Glm4Attention, 1), + "hidden_states": Glm4DecoderLayer, + "attentions": Glm4Attention, } def _init_weights(self, module): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 0aa2ed906fa..d93ee4666ea 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -368,8 +368,8 @@ class 
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (GPTNeoXDecoderLayer, 0),
-        "attentions": (GPTNeoXAttention, 1),
+        "hidden_states": GPTNeoXDecoderLayer,
+        "attentions": GPTNeoXAttention,
     }

     _keys_to_ignore_on_load_unexpected = [r"attention.bias", r"attention.masked_bias"]
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index e3a772f1393..3b217b77c6e 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -313,8 +313,8 @@ class GranitePreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (GraniteDecoderLayer, 0),
-        "attentions": (GraniteAttention, 1),
+        "hidden_states": GraniteDecoderLayer,
+        "attentions": GraniteAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 36df78fc5a8..78654cbf48a 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -320,8 +320,8 @@ class HeliumPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (HeliumDecoderLayer, 0),
-        "attentions": (HeliumAttention, 1),
+        "hidden_states": HeliumDecoderLayer,
+        "attentions": HeliumAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py
index e6c272b5f6e..23c212d2d36 100644
--- a/src/transformers/models/mlcd/modeling_mlcd.py
+++ b/src/transformers/models/mlcd/modeling_mlcd.py
@@ -29,7 +29,7 @@
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_int
+from ...utils import TransformersKwargs, auto_docstring, torch_int
 from .configuration_mlcd import MLCDVisionConfig

@@ -370,7 +370,6 @@ class MLCDEncoder(nn.Module):
         self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False

-    @can_return_tuple
     def forward(
         self,
         inputs_embeds: torch.FloatTensor,
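A note on the `@can_return_tuple` removal in `MLCDEncoder` (my reading, not stated in the patch): the decorator converts a `ModelOutput` back to a legacy tuple when the caller passes `return_dict=False`, and under the new recording machinery that conversion is handled once at the top-level forward wrapped by `@check_model_inputs`, so inner encoders no longer need it. What the decorator automated, in terms of the public `ModelOutput` API:

```python
import torch
from transformers.modeling_outputs import BaseModelOutput

out = BaseModelOutput(last_hidden_state=torch.zeros(1, 4, 8))

# @can_return_tuple boils down to calling .to_tuple() on the returned
# ModelOutput when the caller requested return_dict=False.
legacy = out.to_tuple()
assert isinstance(legacy, tuple)
assert torch.equal(legacy[0], out.last_hidden_state)
```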
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 4a7b217f6ca..a51c9ef2306 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -298,8 +298,8 @@ class OlmoPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (OlmoDecoderLayer, 0),
-        "attentions": (OlmoAttention, 1),
+        "hidden_states": OlmoDecoderLayer,
+        "attentions": OlmoAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 63989e8e574..1b52e8dd638 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -303,8 +303,8 @@ class Olmo2PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Olmo2DecoderLayer, 0),
-        "attentions": (Olmo2Attention, 1),
+        "hidden_states": Olmo2DecoderLayer,
+        "attentions": Olmo2Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index ddd8c85f7ab..5a2678b196c 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -303,8 +303,8 @@ class PhiPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (PhiDecoderLayer, 0),
-        "attentions": (PhiAttention, 1),
+        "hidden_states": PhiDecoderLayer,
+        "attentions": PhiAttention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index b34cd86d98c..27d571227a4 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -267,8 +267,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Qwen2DecoderLayer, 0),
-        "attentions": (Qwen2Attention, 1),
+        "hidden_states": Qwen2DecoderLayer,
+        "attentions": Qwen2Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index c2327e646ab..d842508962a 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -293,8 +293,8 @@ class Qwen3PreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Qwen3DecoderLayer, 0),
-        "attentions": (Qwen3Attention, 1),
+        "hidden_states": Qwen3DecoderLayer,
+        "attentions": Qwen3Attention,
     }

     def _init_weights(self, module):
diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py
index 84bca47c34c..1908d5e1cd1 100644
--- a/src/transformers/models/sam_hq/modeling_sam_hq.py
+++ b/src/transformers/models/sam_hq/modeling_sam_hq.py
@@ -30,13 +30,13 @@ from torch import Tensor, nn

 from transformers.modeling_outputs import ModelOutput
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs

 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput
 from ...processing_utils import Unpack
-from ...utils import auto_docstring, logging
+from ...utils import auto_docstring, can_return_tuple, logging
 from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig

@@ -474,7 +474,10 @@ class SamHQVisionNeck(nn.Module):


 class SamHQVisionEncoder(PreTrainedModel):
-    _can_record_outputs = {"hidden_states": (SamHQVisionLayer, 0), "vision_attentions": (SamHQVisionAttention, 1)}
+    _can_record_outputs = {
+        "hidden_states": OutputRecorder(SamHQVisionLayer),
"vision_attentions": OutputRecorder(SamHQVisionAttention, index=1), + } def __init__(self, config: SamHQVisionConfig): super().__init__(config) @@ -837,7 +840,9 @@ class SamHQFeedForward(nn.Module): class SamHQMaskDecoder(PreTrainedModel): - _can_record_outputs = {"mask_decoder_attentions": (SamHQVisionAttention, 1, "transformer")} + _can_record_outputs = { + "mask_decoder_attentions": OutputRecorder(SamHQVisionAttention, index=1, layer_name="transformer") + } def __init__(self, config: SamHQMaskDecoderConfig): super().__init__(config) diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index bec9d32aa5e..e6d590ec74a 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -297,8 +297,8 @@ class SmolLM3PreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (SmolLM3DecoderLayer, 0), - "attentions": (SmolLM3Attention, 1), + "hidden_states": SmolLM3DecoderLayer, + "attentions": SmolLM3Attention, } def _init_weights(self, module): diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index d68b44db932..3d8bd0522dd 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -590,8 +590,8 @@ class T5GemmaPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True _can_record_outputs: dict[str, tuple[nn.Module, int]] = { - "hidden_states": (T5GemmaDecoderLayer, 0), - "attentions": (T5GemmaAttention, 1), + "hidden_states": T5GemmaDecoderLayer, + "attentions": T5GemmaAttention, } def _init_weights(self, module):