diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst
index c6b990f2139..7d037da34f0 100644
--- a/docs/source/pretrained_models.rst
+++ b/docs/source/pretrained_models.rst
@@ -79,6 +79,14 @@ Here is the full list of the currently provided pretrained models together with
 |                   | ``bert-base-japanese-char-whole-word-masking``             | | 12-layer, 768-hidden, 12-heads, 110M parameters.                                       |
 |                   |                                                            | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. |
 |                   |                                                            |   (see `details on cl-tohoku repository `__).                                            |
+|                   +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
+|                   | ``bert-base-finnish-cased-v1``                             | | 12-layer, 768-hidden, 12-heads, 110M parameters.                                       |
+|                   |                                                            | | Trained on cased Finnish text.                                                         |
+|                   |                                                            |   (see `details on turkunlp.org `__).                                                    |
+|                   +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
+|                   | ``bert-base-finnish-uncased-v1``                           | | 12-layer, 768-hidden, 12-heads, 110M parameters.                                       |
+|                   |                                                            | | Trained on uncased Finnish text.                                                       |
+|                   |                                                            |   (see `details on turkunlp.org `__).                                                    |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | GPT               | ``openai-gpt``                                             | | 12-layer, 768-hidden, 12-heads, 110M parameters.                                       |
 |                   |                                                            | | OpenAI GPT English model                                                               |
diff --git a/requirements.txt b/requirements.txt
index 9c43abc6d76..32edee07125 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ boto3
 # Used for downloading models over HTTP
 requests
 # For OpenAI GPT
-regex
+regex != 2019.12.17
 # For XLNet
 sentencepiece
 # For XLM
diff --git a/setup.py b/setup.py
index eacb5ecec0d..bf09a7d48a1 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,7 @@ setup(
                       'boto3',
                       'requests',
                       'tqdm',
-                      'regex',
+                      'regex != 2019.12.17',
                       'sentencepiece',
                       'sacremoses'],
     entry_points={
diff --git a/templates/adding_a_new_model/tokenization_xxx.py b/templates/adding_a_new_model/tokenization_xxx.py
index 3d6b4ad9df9..7a10a41e5ac 100644
--- a/templates/adding_a_new_model/tokenization_xxx.py
+++ b/templates/adding_a_new_model/tokenization_xxx.py
@@ -85,7 +85,7 @@ class XxxTokenizer(PreTrainedTokenizer):
 
     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
-        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
+        do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
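The docstring fix above (repeated later in this diff for tokenization_bert.py and tokenization_distilbert.py) points at the real do_basic_tokenize argument instead of the nonexistent do_wordpiece_only. A minimal sketch of what the corrected wording means in practice; bert-base-uncased is used here only as a convenient public vocabulary and is not part of this change:

    from transformers import BertTokenizer

    # do_lower_case is applied by the basic tokenizer, so it only matters
    # when do_basic_tokenize=True.
    with_basic = BertTokenizer.from_pretrained('bert-base-uncased',
                                               do_lower_case=True,
                                               do_basic_tokenize=True)
    wordpiece_only = BertTokenizer.from_pretrained('bert-base-uncased',
                                                   do_lower_case=True,
                                                   do_basic_tokenize=False)

    print(with_basic.tokenize("Hello World"))      # lower-cased before WordPiece
    print(wordpiece_only.tokenize("Hello World"))  # handed to WordPiece as-is; do_lower_case is ignored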
diff --git a/transformers/configuration_bert.py b/transformers/configuration_bert.py
index 9072820bce7..7b495013ff4 100644
--- a/transformers/configuration_bert.py
+++ b/transformers/configuration_bert.py
@@ -45,7 +45,9 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
     'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
     'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
-    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json"
+    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
+    'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
+    'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
 }
 
 
diff --git a/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py b/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
index b4dc1bb61bc..fedfc1ecb8a 100644
--- a/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
+++ b/transformers/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
@@ -20,6 +20,13 @@ import argparse
 import logging
 import numpy as np
 import torch
+import pathlib
+
+import fairseq
+from packaging import version
+
+if version.parse(fairseq.__version__) < version.parse("0.9.0"):
+    raise Exception("requires fairseq >= 0.9.0")
 
 from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
@@ -45,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     """
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
+    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     config = BertConfig(
-        vocab_size=50265,
+        vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
         num_attention_heads=roberta.args.encoder_attention_heads,
@@ -64,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 
     # Now let's copy all the weights.
     # Embeddings
-    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
     model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
     model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight)  # just zero them out b/c RoBERTa doesn't use them.
@@ -79,15 +86,18 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         ### self attention
         self_attn: BertSelfAttention = layer.attention.self
         assert(
-            roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
+            roberta_layer.self_attn.k_proj.weight.data.shape == \
+            roberta_layer.self_attn.q_proj.weight.data.shape == \
+            roberta_layer.self_attn.v_proj.weight.data.shape == \
+            torch.Size((config.hidden_size, config.hidden_size))
         )
-        # we use three distinct linear layers so we split the source layer here.
-        self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
-        self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
-        self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
-        self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
-        self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
-        self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
+
+        self_attn.query.weight.data = roberta_layer.self_attn.q_proj.weight
+        self_attn.query.bias.data = roberta_layer.self_attn.q_proj.bias
+        self_attn.key.weight.data = roberta_layer.self_attn.k_proj.weight
+        self_attn.key.bias.data = roberta_layer.self_attn.k_proj.bias
+        self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
+        self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
 
         ### self-attention output
         self_output: BertSelfOutput = layer.attention.output
@@ -151,6 +161,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     if not success:
         raise Exception("Something went wRoNg")
 
+    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
     print(f"Saving model to {pytorch_dump_folder_path}")
     model.save_pretrained(pytorch_dump_folder_path)
 
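The conversion script now copies the separate q_proj / k_proj / v_proj weights exposed by fairseq >= 0.9.0 (hence the version guard above), instead of slicing the single combined in_proj_weight that older fairseq releases stored. A self-contained, shapes-only sketch of the layout the removed slicing code assumed; hidden_size and the random tensor are stand-ins, no fairseq checkpoint is needed:

    import torch

    hidden_size = 16  # stand-in for config.hidden_size

    # Older fairseq kept one projection of shape (3 * hidden, hidden),
    # stacked as [query; key; value] along dim 0.
    in_proj_weight = torch.randn(3 * hidden_size, hidden_size)

    q_weight = in_proj_weight[:hidden_size, :]
    k_weight = in_proj_weight[hidden_size:2 * hidden_size, :]
    v_weight = in_proj_weight[2 * hidden_size:, :]

    # fairseq >= 0.9.0 hands these out as three separate (hidden, hidden) matrices,
    # which is exactly what the new assert checks before copying into BertSelfAttention.
    assert q_weight.shape == k_weight.shape == v_weight.shape == torch.Size((hidden_size, hidden_size))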
diff --git a/transformers/file_utils.py b/transformers/file_utils.py
index 81c9b8002f5..16010f7e0ad 100644
--- a/transformers/file_utils.py
+++ b/transformers/file_utils.py
@@ -26,14 +26,6 @@ from contextlib import contextmanager
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
-try:
-    import tensorflow as tf
-    assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
-    _tf_available = True  # pylint: disable=invalid-name
-    logger.info("TensorFlow version {} available.".format(tf.__version__))
-except (ImportError, AssertionError):
-    _tf_available = False  # pylint: disable=invalid-name
-
 try:
     import torch
     _torch_available = True  # pylint: disable=invalid-name
@@ -41,6 +33,13 @@ try:
 except ImportError:
     _torch_available = False  # pylint: disable=invalid-name
 
+try:
+    import tensorflow as tf
+    assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
+    _tf_available = True  # pylint: disable=invalid-name
+    logger.info("TensorFlow version {} available.".format(tf.__version__))
+except (ImportError, AssertionError):
+    _tf_available = False  # pylint: disable=invalid-name
 
 try:
     from torch.hub import _get_torch_home
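Reordering the probes in file_utils.py means PyTorch is detected (and logged) before TensorFlow; either framework may still be absent. A quick way to see what was detected, assuming the is_torch_available / is_tf_available helpers that file_utils exposes around these flags:

    # Reads the _torch_available / _tf_available flags set by the try/except blocks above.
    from transformers.file_utils import is_tf_available, is_torch_available

    print("PyTorch available:", is_torch_available())
    print("TensorFlow 2.x available:", is_tf_available())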
diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index d0f35272ac0..afeb9d8e21c 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -51,7 +51,9 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
     'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
     'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
-    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin"
+    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
+    'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
+    'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
 }
 
 
diff --git a/transformers/modeling_tf_bert.py b/transformers/modeling_tf_bert.py
index 7cc71f50633..b4f97c06d9f 100644
--- a/transformers/modeling_tf_bert.py
+++ b/transformers/modeling_tf_bert.py
@@ -51,7 +51,9 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
     'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
     'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
-    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5"
+    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
+    'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
+    'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
 }
 
 
diff --git a/transformers/tokenization_bert.py b/transformers/tokenization_bert.py
index ded5072e588..edc26d88cf9 100644
--- a/transformers/tokenization_bert.py
+++ b/transformers/tokenization_bert.py
@@ -46,6 +46,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
         'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
         'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
+        'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
+        'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
     }
 }
 
@@ -65,6 +67,8 @@
     'bert-base-cased-finetuned-mrpc': 512,
     'bert-base-german-dbmdz-cased': 512,
     'bert-base-german-dbmdz-uncased': 512,
+    'bert-base-finnish-cased-v1': 512,
+    'bert-base-finnish-uncased-v1': 512,
 }
 
 PRETRAINED_INIT_CONFIGURATION = {
@@ -83,6 +87,8 @@ PRETRAINED_INIT_CONFIGURATION = {
     'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
     'bert-base-german-dbmdz-cased': {'do_lower_case': False},
     'bert-base-german-dbmdz-uncased': {'do_lower_case': True},
+    'bert-base-finnish-cased-v1': {'do_lower_case': False},
+    'bert-base-finnish-uncased-v1': {'do_lower_case': True},
 }
 
 
@@ -113,12 +119,12 @@ class BertTokenizer(PreTrainedTokenizer):
 
     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
-        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
+        do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
         do_basic_tokenize: Whether to do basic tokenization before wordpiece.
         max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
             minimum of this value (if specified) and the underlying BERT model's sequence length.
         never_split: List of tokens which will never be split during tokenization. Only has an effect when
-            do_wordpiece_only=False
+            do_basic_tokenize=True
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
diff --git a/transformers/tokenization_distilbert.py b/transformers/tokenization_distilbert.py
index f40bf2bd77e..2f245d71dca 100644
--- a/transformers/tokenization_distilbert.py
+++ b/transformers/tokenization_distilbert.py
@@ -53,12 +53,12 @@ class DistilBertTokenizer(BertTokenizer):
 
     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
-        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
+        do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
         max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
            minimum of this value (if specified) and the underlying BERT model's sequence length.
         never_split: List of tokens which will never be split during tokenization. Only has an effect when
-            do_wordpiece_only=False
+            do_basic_tokenize=True
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
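Taken together, the new entries in configuration_bert.py, modeling_bert.py, modeling_tf_bert.py and tokenization_bert.py make the FinBERT checkpoints loadable by shortcut name. A minimal PyTorch-side sketch (the weights are downloaded on first use, the Finnish sample sentence is arbitrary, and the TF weights would load analogously through the TF classes):

    import torch
    from transformers import BertModel, BertTokenizer

    # The new shortcut names resolve to the TurkuNLP files added to the archive maps above.
    tokenizer = BertTokenizer.from_pretrained('bert-base-finnish-cased-v1')
    model = BertModel.from_pretrained('bert-base-finnish-cased-v1')
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Terve maailma!", add_special_tokens=True)])
    with torch.no_grad():
        outputs = model(input_ids)
    print(outputs[0].shape)  # last hidden states: (1, sequence_length, 768)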