[Mamba] from pretrained issue with self.embeddings (#29851)

* nit

* update

* oups

* Update src/transformers/models/mamba/modeling_mamba.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
Arthur 2024-03-28 21:54:51 +09:00 committed by GitHub
parent 441de62f49
commit e677479c81


@@ -501,8 +501,15 @@ class MambaModel(MambaPreTrainedModel):
         self.gradient_checkpointing = False
         self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         # Initialize weights and apply final processing
+        self._register_load_state_dict_pre_hook(self.load_hook)
         self.post_init()
 
+    def load_hook(self, state_dict, prefix, *args):
+        for k in state_dict:
+            if "embedding." in k:
+                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+                break
+
     def get_input_embeddings(self):
         return self.embeddings
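
For context: `nn.Module._register_load_state_dict_pre_hook` lets a module rewrite checkpoint keys before `load_state_dict` validates them, which is how checkpoints saved under the old `embedding.` key names keep loading after the attribute was renamed to `embeddings`. Below is a minimal, self-contained sketch of the same pattern on a toy module (the `TinyModel` name and shapes are illustrative, not from this PR):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Toy module whose weight attribute was renamed from `embedding` to `embeddings`."""

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        # Rewrite legacy checkpoint keys before load_state_dict validates them,
        # mirroring the hook registered in MambaModel above.
        self._register_load_state_dict_pre_hook(self._rename_legacy_keys)

    def _rename_legacy_keys(self, state_dict, prefix, *args):
        # Snapshot the keys so the dict can be mutated while iterating.
        for key in list(state_dict):
            if "embedding." in key:
                state_dict[key.replace("embedding.", "embeddings.")] = state_dict.pop(key)


# An "old" checkpoint saved under the previous attribute name.
legacy_state_dict = {"embedding.weight": torch.randn(10, 4)}

model = TinyModel()
model.load_state_dict(legacy_state_dict)  # loads cleanly thanks to the pre-hook
print(model.embeddings.weight.shape)      # torch.Size([10, 4])
```

In the PR itself, the hook serves the same purpose for `from_pretrained`: Mamba checkpoints saved before the rename still map onto `self.embeddings` without a conversion script.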