mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Fixed XLM weights conversion script. Added 5 new checkpoints for XLM.
This commit is contained in:
parent
c82b74b996
commit
dee3e45b93
@ -23,7 +23,8 @@ from io import open
|
|||||||
import torch
|
import torch
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
|
from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel)
|
||||||
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
|
|||||||
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray)))
|
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray)))
|
||||||
|
|
||||||
vocab = chkpt['dico_word2id']
|
vocab = chkpt['dico_word2id']
|
||||||
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in d.items())
|
vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items())
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||||
|
@ -37,9 +37,19 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
|
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
|
||||||
|
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
|
||||||
|
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
|
||||||
|
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
|
||||||
|
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
|
||||||
|
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||||
|
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.bin",
|
||||||
|
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-configl.bin",
|
||||||
|
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.bin",
|
||||||
|
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.bin",
|
||||||
|
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.bin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,10 +36,20 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
'vocab_file':
|
'vocab_file':
|
||||||
{
|
{
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
|
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
|
||||||
|
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.bin",
|
||||||
|
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.bin",
|
||||||
|
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.bin",
|
||||||
|
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.bin",
|
||||||
|
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.bin",
|
||||||
},
|
},
|
||||||
'merges_file':
|
'merges_file':
|
||||||
{
|
{
|
||||||
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
|
||||||
|
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.bin",
|
||||||
|
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.bin",
|
||||||
|
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.bin",
|
||||||
|
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.bin",
|
||||||
|
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.bin",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user