Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Mamba] from_pretrained issue with self.embeddings (#29851)
* nit
* update
* oups
* Update src/transformers/models/mamba/modeling_mamba.py
  Co-authored-by: Lysandre Debut <hi@lysand.re>

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent 441de62f49
commit e677479c81
@@ -501,8 +501,15 @@ class MambaModel(MambaPreTrainedModel):
         self.gradient_checkpointing = False
         self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         # Initialize weights and apply final processing
+        self._register_load_state_dict_pre_hook(self.load_hook)
         self.post_init()
 
+    def load_hook(self, state_dict, prefix, *args):
+        for k in state_dict:
+            if "embedding." in k:
+                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+                break
+
     def get_input_embeddings(self):
         return self.embeddings
 