[emu3] fix conversion script (#38297)

* fix conversion script and update weights

* fixup

* remove commented line
This commit is contained in:
Raushan Turganbay 2025-05-23 09:49:56 +02:00 committed by GitHub
parent 2b585419b4
commit b01984a51d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 3 deletions

View File

@ -211,14 +211,13 @@ def convert_tiktoken(tokenizer, output_dir):
KEYS_TO_MODIFY_MAPPING = {
"^model": "model.text_model",
"^encoder": "model.vqmodel.encoder",
"^decoder": "model.vqmodel.decoder",
"^post_quant_conv": "model.vqmodel.post_quant_conv",
"^quant_conv": "model.vqmodel.quant_conv",
"^quantize": "model.vqmodel.quantize",
"^model": "text_model.model",
r"lm_head\.weight": "text_model.lm_head.weight",
r"^text_model\.model\.vqmodel": "vqmodel",
r"lm_head\.weight": "lm_head.weight",
# rename QKV proj for the VQ-VAE model because we use SiglipAttention
r"\.q\.": ".q_proj.",
r"\.k\.": ".k_proj.",

View File

@ -1598,6 +1598,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
def vqmodel(self):
return self.model.vqmodel
# Property returning `self.model.vocabulary_mapping`, which makes the wrapped
# model's vocabulary mapping reachable as an attribute of this class.
@property
def vocabulary_mapping(self):
return self.model.vocabulary_mapping
# Thin pass-through: forwards every keyword argument unchanged to
# `self.model.decode_image_tokens` and returns whatever that call returns.
# The signature accepts only `**kwargs`, so positional arguments are rejected.
def decode_image_tokens(self, **kwargs):
return self.model.decode_image_tokens(**kwargs)
@can_return_tuple
@auto_docstring
def forward(

View File

@ -1077,6 +1077,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
def vqmodel(self):
return self.model.vqmodel
# Property returning `self.model.vocabulary_mapping`, which makes the wrapped
# model's vocabulary mapping reachable as an attribute of this class.
@property
def vocabulary_mapping(self):
return self.model.vocabulary_mapping
# Thin pass-through: forwards every keyword argument unchanged to
# `self.model.decode_image_tokens` and returns whatever that call returns.
# The signature accepts only `**kwargs`, so positional arguments are rejected.
def decode_image_tokens(self, **kwargs):
return self.model.decode_image_tokens(**kwargs)
@can_return_tuple
@auto_docstring
def forward(