move the fix a bit

Arthur 2025-07-01 15:00:38 +02:00
parent 00afce9837
commit 3ac6c52f34
2 changed files with 10 additions and 8 deletions

View File

@@ -1103,10 +1103,6 @@ class Emu3PreTrainedModel(PreTrainedModel):
     _supports_param_buffer_assignment = False
     _supports_flex_attn = True
     _supports_attention_backend = True
-    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Emu3DecoderLayer, 0),
-        "attentions": (Emu3Attention, 1),
-    }
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -1158,6 +1154,11 @@ class Emu3RotaryEmbedding(nn.Module):
 
 @auto_docstring
 class Emu3TextModel(Emu3PreTrainedModel):
+    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
+        "hidden_states": (Emu3DecoderLayer, 0),
+        "attentions": (Emu3Attention, 1),
+    }
+
     def __init__(self, config: Emu3Config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id

View File

@@ -846,10 +846,6 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
     ]
     _supports_flex_attn = True
     _supports_attention_backend = True
-    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Emu3DecoderLayer, 0),
-        "attentions": (Emu3Attention, 1),
-    }
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -866,6 +862,11 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
 
 
 class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
+    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
+        "hidden_states": (Emu3DecoderLayer, 0),
+        "attentions": (Emu3Attention, 1),
+    }
+
     def __init__(self, config: Emu3Config):
         super().__init__(config)
         self.layers = nn.ModuleList(
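
For context on the attribute being moved: `_can_record_outputs` maps an output name to a `(module class, index into that module's forward return)` pair, and the transformers output-recording machinery uses it to capture those values via forward hooks instead of threading `output_attentions`-style flags through every call. The sketch below reproduces that pattern in plain PyTorch under those assumptions; `record_outputs` and the stand-in modules are hypothetical, not the library's actual API. Declaring the mapping on `Emu3TextModel` rather than on the shared `Emu3PreTrainedModel` base plausibly scopes recording to the text decoder instead of every Emu3 subclass.

# Minimal sketch of the recording pattern, in plain PyTorch (not the
# transformers internals). `record_outputs` is a hypothetical helper.
import torch
from torch import nn

def record_outputs(model: nn.Module, spec: dict[str, tuple[type, int]]):
    """For each name -> (module_type, index) entry, hook every matching
    submodule and collect element `index` of its forward output."""
    records = {name: [] for name in spec}
    handles = []

    def make_hook(name, index):
        def hook(module, args, output):
            # Forward may return a tensor or a tuple; normalize to a tuple
            # so `index` addresses the same slot either way.
            out = output if isinstance(output, tuple) else (output,)
            records[name].append(out[index])
        return hook

    for name, (module_type, index) in spec.items():
        for module in model.modules():
            if isinstance(module, module_type):
                handles.append(module.register_forward_hook(make_hook(name, index)))
    return records, handles

# Usage, with nn.Linear standing in for Emu3DecoderLayer:
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
records, handles = record_outputs(model, {"hidden_states": (nn.Linear, 0)})
model(torch.randn(2, 4))
for handle in handles:
    handle.remove()
assert len(records["hidden_states"]) == 2  # one capture per matching layer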