protect import

This commit is contained in:
Arthur 2025-07-03 16:35:51 +02:00
parent c5592be0ff
commit f6190cbf20
3 changed files with 6 additions and 6 deletions

View File

@@ -300,9 +300,9 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
_can_record_outputs: = {
"hidden_states": Starcoder2DecoderLayer,
"attentions": Starcoder2Attention,
_can_record_outputs: dict[str, tuple[nn.Module, int]] = {
"hidden_states": (Starcoder2DecoderLayer, 0),
"attentions": (Starcoder2Attention, 1),
}
def _init_weights(self, module):

View File

@@ -687,8 +687,8 @@ def make_default_2d_attention_mask(
class T5GemmaEncoder(T5GemmaPreTrainedModel):
_can_record_outputs = {
"attentions": (T5GemmaSelfAttention, 1),
"hidden_states": (T5GemmaEncoderLayer, 1),
"attentions": T5GemmaSelfAttention,
"hidden_states": T5GemmaEncoderLayer,
}
def __init__(self, config):

View File

@@ -982,7 +982,7 @@ class OutputRecorder:
layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
"""
target_class: torch.nn.Module
target_class: "Type[torch.nn.Module]"
index: Optional[int] = 0
layer_name: Optional[str] = None