Merge remote-tracking branch 'upstream/master' into xlmr

Stefan Schweter 2019-12-18 11:05:16 +01:00
commit cce3089b65
11 changed files with 58 additions and 28 deletions

View File

@@ -79,6 +79,14 @@ Here is the full list of the currently provided pretrained models together with
| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. |
| | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on uncased Finnish text. |
| | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | OpenAI GPT English model |
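Since the new Finnish checkpoints are standard BERT models, they load through the usual classes once this merge lands; a minimal sketch (the model identifier comes from the archive maps added further down in this diff):

    from transformers import BertTokenizer, BertModel

    # cased Finnish BERT from TurkuNLP, resolved via the new S3 entries
    tokenizer = BertTokenizer.from_pretrained("bert-base-finnish-cased-v1")
    model = BertModel.from_pretrained("bert-base-finnish-cased-v1")

    input_ids = tokenizer.encode("Hyvää huomenta!", return_tensors="pt")
    last_hidden_state = model(input_ids)[0]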

View File

@@ -5,7 +5,7 @@ boto3
# Used for downloading models over HTTP
requests
# For OpenAI GPT
regex
regex != 2019.12.17
# For XLNet
sentencepiece
# For XLM

View File

@@ -59,7 +59,7 @@ setup(
'boto3',
'requests',
'tqdm',
'regex',
'regex != 2019.12.17',
'sentencepiece',
'sacremoses'],
entry_points={

View File

@@ -85,7 +85,7 @@ class XxxTokenizer(PreTrainedTokenizer):
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
"""
vocab_files_names = VOCAB_FILES_NAMES

View File

@@ -45,7 +45,9 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json"
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
}

View File

@@ -20,6 +20,13 @@ import argparse
import logging
import numpy as np
import torch
import pathlib
import fairseq
from packaging import version
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
@@ -45,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
"""
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
config = BertConfig(
vocab_size=50265,
vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads,
@@ -64,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
# Now let's copy all the weights.
# Embeddings
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
@@ -79,15 +86,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
### self attention
self_attn: BertSelfAttention = layer.attention.self
assert(
roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
roberta_layer.self_attn.k_proj.weight.data.shape == \
roberta_layer.self_attn.q_proj.weight.data.shape == \
roberta_layer.self_attn.v_proj.weight.data.shape == \
torch.Size((config.hidden_size, config.hidden_size))
)
# we use three distinct linear layers so we split the source layer here.
self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias
self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight
self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias
self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
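# Note: fairseq versions before 0.9.0 stored these three projections in a single
# combined matrix (roughly in_proj_weight == torch.cat([q_proj.weight, k_proj.weight,
# v_proj.weight], dim=0)), which is why the old code sliced it into thirds; fairseq
# >= 0.9.0 exposes q_proj/k_proj/v_proj directly, so the copies above map one-to-one.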
### self-attention output
self_output: BertSelfOutput = layer.attention.output
@@ -151,6 +161,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
if not success:
raise Exception("Something went wRoNg")
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
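The success flag tested above is computed earlier in the script from a comparison of the original fairseq model against the converted one; that code is outside this hunk, but a sketch of such a check (hypothetical input batch, with roberta and model as in the code above) would look roughly like:

    import torch

    input_ids = torch.tensor([[0, 31414, 232, 2]])   # hypothetical token ids
    their_output = roberta.model(input_ids)[0]       # original fairseq RoBERTa output
    our_output = model(input_ids)[0]                 # converted model output
    success = torch.allclose(our_output, their_output, atol=1e-3)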

View File

@@ -26,14 +26,6 @@ from contextlib import contextmanager
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
import torch
_torch_available = True # pylint: disable=invalid-name
@@ -41,6 +33,13 @@ try:
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
from torch.hub import _get_torch_home
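Downstream code gates framework-specific imports on the flags set here through the module's public helpers; a minimal usage sketch (helper names as exposed by transformers.file_utils):

    from transformers.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        import torch              # PyTorch-only code paths
    if is_tf_available():
        import tensorflow as tf   # TF 2.x-only code paths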

View File

@@ -51,7 +51,9 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin"
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
}

View File

@@ -51,7 +51,9 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5"
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
}

View File

@@ -46,6 +46,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
}
}
@@ -65,6 +67,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-cased-finetuned-mrpc': 512,
'bert-base-german-dbmdz-cased': 512,
'bert-base-german-dbmdz-uncased': 512,
'bert-base-finnish-cased-v1': 512,
'bert-base-finnish-uncased-v1': 512,
}
PRETRAINED_INIT_CONFIGURATION = {
@@ -83,6 +87,8 @@ PRETRAINED_INIT_CONFIGURATION = {
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
'bert-base-german-dbmdz-cased': {'do_lower_case': False},
'bert-base-german-dbmdz-uncased': {'do_lower_case': True},
'bert-base-finnish-cased-v1': {'do_lower_case': False},
'bert-base-finnish-uncased-v1': {'do_lower_case': True},
}
@@ -113,12 +119,12 @@ class BertTokenizer(PreTrainedTokenizer):
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
minimum of this value (if specified) and the underlying BERT model's sequence length.
never_split: List of tokens which will never be split during tokenization. Only has an effect when
do_wordpiece_only=False
do_basic_tokenize=True
"""
vocab_files_names = VOCAB_FILES_NAMES
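As a usage sketch of the corrected docstring: do_lower_case and never_split only apply while the basic tokenizer runs, i.e. when do_basic_tokenize=True (the Finnish identifier resolves through the PRETRAINED_INIT_CONFIGURATION entries added above; the local vocab path is hypothetical):

    from transformers import BertTokenizer

    # picks up do_lower_case=True automatically from PRETRAINED_INIT_CONFIGURATION
    tokenizer = BertTokenizer.from_pretrained("bert-base-finnish-uncased-v1")

    # explicit construction from a local vocab file (hypothetical path);
    # do_lower_case and never_split take effect because do_basic_tokenize=True
    tokenizer = BertTokenizer("vocab.txt", do_lower_case=True,
                              do_basic_tokenize=True, never_split=["[unused0]"])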

View File

@@ -53,12 +53,12 @@ class DistilBertTokenizer(BertTokenizer):
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
minimum of this value (if specified) and the underlying BERT model's sequence length.
never_split: List of tokens which will never be split during tokenization. Only has an effect when
do_wordpiece_only=False
do_basic_tokenize=True
"""
vocab_files_names = VOCAB_FILES_NAMES