Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 18:51:14 +06:00
move the fix a bit

This commit moves the _can_record_outputs registry out of Emu3PreTrainedModel and onto Emu3TextModel, applying the same change in both Emu3 files touched below.

parent 00afce9837
commit 3ac6c52f34
@@ -1103,10 +1103,6 @@ class Emu3PreTrainedModel(PreTrainedModel):
     _supports_param_buffer_assignment = False
     _supports_flex_attn = True
     _supports_attention_backend = True
-    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Emu3DecoderLayer, 0),
-        "attentions": (Emu3Attention, 1),
-    }
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -1158,6 +1154,11 @@ class Emu3RotaryEmbedding(nn.Module):
 
 @auto_docstring
 class Emu3TextModel(Emu3PreTrainedModel):
+    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
+        "hidden_states": (Emu3DecoderLayer, 0),
+        "attentions": (Emu3Attention, 1),
+    }
+
     def __init__(self, config: Emu3Config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
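The two hunks above remove the _can_record_outputs registry from the pretrained base class and re-declare it on Emu3TextModel; the two hunks below appear to apply the same move in the modular Emu3 source (note the ChameleonPreTrainedModel and LlamaModel bases). For orientation, a minimal sketch of the pattern being moved, with stand-in classes rather than the real transformers internals:

# Illustrative sketch only: stand-in classes, not the transformers internals.
from torch import nn

class Emu3DecoderLayerStub(nn.Module):  # stand-in for Emu3DecoderLayer
    pass

class Emu3AttentionStub(nn.Module):  # stand-in for Emu3Attention
    pass

class PreTrainedBase(nn.Module):
    # After this commit, the shared base no longer declares the registry.
    _supports_flex_attn = True
    _supports_attention_backend = True

class TextModel(PreTrainedBase):
    # The registry now lives on the concrete text model. Each entry maps an
    # output name to (module class, index into that module's return values),
    # so only subclasses that actually contain those modules advertise them.
    _can_record_outputs: dict[str, tuple[type[nn.Module], int]] = {
        "hidden_states": (Emu3DecoderLayerStub, 0),
        "attentions": (Emu3AttentionStub, 1),
    }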
@@ -846,10 +846,6 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
     ]
     _supports_flex_attn = True
     _supports_attention_backend = True
-    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
-        "hidden_states": (Emu3DecoderLayer, 0),
-        "attentions": (Emu3Attention, 1),
-    }
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -866,6 +862,11 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
 
 
 class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
+    _can_record_outputs: dict[str, tuple[nn.Module, int]] = {
+        "hidden_states": (Emu3DecoderLayer, 0),
+        "attentions": (Emu3Attention, 1),
+    }
+
     def __init__(self, config: Emu3Config):
         super().__init__(config)
         self.layers = nn.ModuleList(
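In recent transformers, _can_record_outputs appears to be the machinery that feeds the optional hidden_states and attentions outputs, so moving the declaration should be behavior-preserving at the public API level. A hedged usage sketch; the checkpoint id is an assumption for illustration and is not taken from this page:

import torch
from transformers import AutoModelForCausalLM

# Checkpoint id assumed for illustration; substitute a real Emu3 checkpoint.
model = AutoModelForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf")
input_ids = torch.tensor([[1, 2, 3]])
out = model(input_ids, output_hidden_states=True, output_attentions=True)
print(len(out.hidden_states))  # one tensor per recorded decoder layer (plus embeddings)
print(len(out.attentions))     # one attention map per recorded attention module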