From dee3e45b93e65e3de9cdf28f5ffbe148d91e361e Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 10 Jul 2019 19:04:21 -0400 Subject: [PATCH] Fixed XLM weights conversion script. Added 5 new checkpoints for XLM. --- .../convert_xlm_checkpoint_to_pytorch.py | 5 +++-- pytorch_transformers/modeling_xlm.py | 10 ++++++++++ pytorch_transformers/tokenization_xlm.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py index e5815252f12..416f1bc16d3 100755 --- a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py @@ -23,7 +23,8 @@ from io import open import torch import numpy -from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel) +from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME +from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel) from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES @@ -37,7 +38,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray))) vocab = chkpt['dico_word2id'] - vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in d.items()) + vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) # Save pytorch-model pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 72ec6397a02..60dc1a7bed4 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -37,9 +37,19 @@ logger = logging.getLogger(__name__) XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin", + 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin", + 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin", + 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin", + 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin", + 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin", } XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", + 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.bin", + 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-configl.bin", + 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.bin", + 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.bin", + 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.bin", } diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 74bc56f350d..a81567f21c9 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -36,10 +36,20 @@ PRETRAINED_VOCAB_FILES_MAP = { 'vocab_file': { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json", + 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.bin", + 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.bin", + 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.bin", + 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.bin", + 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.bin", }, 'merges_file': { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", + 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.bin", + 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.bin", + 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.bin", + 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.bin", + 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.bin", }, }