Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
a little bit of magic
parent 1f559c676f, commit cd63172ced
@@ -319,8 +319,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (LlamaDecoderLayer, 0),
-        "attentions": (LlamaAttention, 1),
+        "hidden_states": LlamaDecoderLayer,
+        "attentions": LlamaAttention,
     }

     def _init_weights(self, module):
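The hunk above simplifies the _can_record_outputs entries from (module class, output index) tuples to plain module classes, with OutputRecorder (imported in the next hunk) presumably covering cases that still need an explicit index. Below is a minimal sketch of the class-keyed, hook-based capture idea under that assumption; TinyLayer, TinyModel and capture_outputs are illustrative stand-ins, not the transformers implementation.

# Minimal sketch of class-keyed output capture via forward hooks.
# TinyLayer, TinyModel and capture_outputs are illustrative stand-ins,
# not the transformers implementation.
import torch
from torch import nn


class TinyLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class TinyModel(nn.Module):
    # Map an output name to the module class whose forward output gets recorded,
    # mirroring the simplified {"hidden_states": LlamaDecoderLayer} style above.
    _can_record_outputs = {"hidden_states": TinyLayer}

    def __init__(self, dim=8, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([TinyLayer(dim) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def capture_outputs(model, x):
    captured = {key: [] for key in model._can_record_outputs}
    hooks = []
    for key, module_cls in model._can_record_outputs.items():
        for module in model.modules():
            if isinstance(module, module_cls):
                # Record every matching module's forward output under this key.
                hooks.append(
                    module.register_forward_hook(
                        lambda mod, args, out, key=key: captured[key].append(out)
                    )
                )
    try:
        result = model(x)
    finally:
        for h in hooks:
            h.remove()
    return result, captured


out, recorded = capture_outputs(TinyModel(), torch.randn(1, 8))
print(len(recorded["hidden_states"]))  # 2: one tensor per TinyLayer
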
@@ -21,7 +21,7 @@ from torch import nn

 from transformers.modeling_outputs import ModelOutput
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils.generic import TransformersKwargs, can_return_tuple, check_model_inputs, OutputRecorder
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs

 from ...processing_utils import Unpack
 from ...utils import auto_docstring, logging
@@ -26,7 +26,7 @@ from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import Enum
 from functools import partial, wraps
-from typing import Any, Callable, ContextManager, Optional, Type, TypedDict, Union
+from typing import Any, Callable, ContextManager, Optional, TypedDict

 import numpy as np
 from packaging import version
@@ -1061,6 +1061,7 @@ def check_model_inputs(func):
             else:
                 hook_fn = make_capture_fn(key, specs.index)
                 hooks.append(register_hook_if_needed(module, hook_fn))
+
         outputs = func(self, *args, **kwargs)
         for h in hooks:
             if h is not None:
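The last hunk sits in the hook setup and teardown path of check_model_inputs, where registration may yield None entries and removal is guarded accordingly. A standalone sketch of that register-only-if-needed pattern follows; maybe_register and run_with_hooks are hypothetical helpers standing in for the library's register_hook_if_needed machinery, not its actual code.

# Standalone sketch of the register-only-if-needed / guarded-removal pattern
# visible in the hunk above; maybe_register is a hypothetical helper, not the
# library's register_hook_if_needed.
import torch
from torch import nn


def maybe_register(module, hook_fn, wanted_types=(nn.Linear,)):
    # Attach a forward hook only to module types we want to observe;
    # return None otherwise, so callers get a sparse list of handles.
    if isinstance(module, wanted_types):
        return module.register_forward_hook(hook_fn)
    return None


def run_with_hooks(model, x):
    captured = []

    def hook_fn(mod, args, out):
        captured.append(out)

    hooks = [maybe_register(m, hook_fn) for m in model.modules()]
    try:
        out = model(x)
    finally:
        # Some entries are None (no hook was registered), so guard .remove().
        for h in hooks:
            if h is not None:
                h.remove()
    return out, captured


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
out, captured = run_with_hooks(model, torch.randn(1, 4))
print(len(captured))  # 2: only the Linear submodules were hooked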