a little bit of magic

Author: Arthur
Date: 2025-07-03 15:55:57 +02:00
Parent: 1f559c676f
Commit: cd63172ced
3 changed files with 5 additions and 4 deletions


@@ -319,8 +319,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (LlamaDecoderLayer, 0),
-        "attentions": (LlamaAttention, 1),
+        "hidden_states": LlamaDecoderLayer,
+        "attentions": LlamaAttention,
     }

     def _init_weights(self, module):

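The hunk above is the model-side half of the change: entries in _can_record_outputs drop the (module_class, output_index) tuples and become bare module classes. As a rough illustration of the convention this implies (the OutputRecorder used below is the helper imported in the next file; treating it as a wrapper carrying a target class plus an optional output index is an assumption based on this diff, not something the commit spells out):

# Illustration only: a model declaring recordable outputs under the new convention.
# OutputRecorder's exact signature (target class plus an optional index) is assumed here.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaPreTrainedModel,
)
from transformers.utils.generic import OutputRecorder


class TinyLlamaBackbone(LlamaPreTrainedModel):
    _can_record_outputs = {
        # Plain module class: capture that module's forward output as-is.
        "hidden_states": LlamaDecoderLayer,
        # Assumed escape hatch for when a specific tuple slot is still needed,
        # e.g. the attention weights returned as LlamaAttention's second output.
        "attentions": OutputRecorder(LlamaAttention, index=1),
    }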

@@ -21,7 +21,7 @@ from torch import nn
 from transformers.modeling_outputs import ModelOutput
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs, OutputRecorder
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, logging


@@ -26,7 +26,7 @@ from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import Enum
 from functools import partial, wraps
-from typing import Any, Callable, ContextManager, Optional, Type, TypedDict, Union
+from typing import Any, Callable, ContextManager, Optional, TypedDict

 import numpy as np
 from packaging import version
@@ -1061,6 +1061,7 @@ def check_model_inputs(func):
             else:
                 hook_fn = make_capture_fn(key, specs.index)
             hooks.append(register_hook_if_needed(module, hook_fn))
+        outputs = func(self, *args, **kwargs)
         for h in hooks:
             if h is not None:
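The added outputs = func(self, *args, **kwargs) line places the wrapped forward call after the capture hooks have been registered and before they are removed. The sketch below is not the transformers implementation, only a stand-alone version of the same pattern under assumed names (CaptureSpec, record_outputs, ToyBlock and ToyModel are all hypothetical): register a forward hook for each module class to record, run the forward pass while the hooks are live, collect what they captured, then clean up.

# Minimal, self-contained sketch of the hook-based capture pattern; every name
# below (CaptureSpec, record_outputs, ToyBlock, ToyModel) is hypothetical.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn


@dataclass
class CaptureSpec:
    target_class: type           # module class whose outputs we want to record
    index: Optional[int] = None  # tuple slot to keep when the module returns a tuple


def record_outputs(model: nn.Module, specs: dict, *args, **kwargs):
    captured = {key: [] for key in specs}
    hooks = []

    def make_capture_fn(key, index):
        def hook(module, inputs, output):
            value = output[index] if index is not None and isinstance(output, tuple) else output
            captured[key].append(value)
        return hook

    # Register one forward hook per submodule matching a spec's target class.
    for key, spec in specs.items():
        for module in model.modules():
            if isinstance(module, spec.target_class):
                hooks.append(module.register_forward_hook(make_capture_fn(key, spec.index)))

    # Run the forward pass while the hooks are live; in the diff above, the added
    # "outputs = func(self, *args, **kwargs)" line plays this role.
    outputs = model(*args, **kwargs)

    # Always remove the hooks so they do not leak into later forward calls.
    for h in hooks:
        if h is not None:
            h.remove()
    return outputs, captured


if __name__ == "__main__":
    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    outputs, recorded = record_outputs(ToyModel(), {"hidden_states": CaptureSpec(ToyBlock)}, torch.randn(1, 4))
    print(outputs.shape, len(recorded["hidden_states"]))  # torch.Size([1, 4]) 2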