mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add MobileBert (#4901)
* Add MobileBert
* Quality + Conversion script
* Style
* Update src/transformers/modeling_mobilebert.py
* Links to S3
* Style
* TFMobileBert, with slight fixes to the PyTorch MobileBert and style
* MobileBertForMaskedLM (PT + TF)
* MobileBertForNextSentencePrediction (PT + TF)
* MobileBertFor{MultipleChoice, TokenClassification} (PT + TF)
* Tests + Auto
* Doc
* Tests
* Addressing @sgugger's comments
* Addressing @patrickvonplaten's comments
* Style
* Style
* Integration test
* Style
* Model card

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
f45e873910
commit
9a3f91088c
@@ -184,3 +184,4 @@ conversion utilities for the following models:
    model_doc/marian
    model_doc/longformer
    model_doc/retribert
    model_doc/mobilebert
169 docs/source/model_doc/mobilebert.rst (new file)
@@ -0,0 +1,169 @@
MobileBERT
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The MobileBERT model was proposed in `MobileBERT: a Compact Task-Agnostic BERT
for Resource-Limited Devices <https://arxiv.org/abs/2004.02984>`__
by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. It's a bidirectional transformer
based on the BERT model, which is compressed and accelerated using several approaches.

The abstract from the paper is the following:

*Natural Language Processing (NLP) has recently achieved great success by using huge pre-trained models with hundreds
of millions of parameters. However, these models suffer from heavy model sizes and high latency such that they cannot
be deployed to resource-limited mobile devices. In this paper, we propose MobileBERT for compressing and accelerating
the popular BERT model. Like the original BERT, MobileBERT is task-agnostic, that is, it can be generically applied
to various downstream NLP tasks via simple fine-tuning. Basically, MobileBERT is a thin version of BERT_LARGE, while
equipped with bottleneck structures and a carefully designed balance between self-attentions and feed-forward
networks. To train MobileBERT, we first train a specially designed teacher model, an inverted-bottleneck incorporated
BERT_LARGE model. Then, we conduct knowledge transfer from this teacher to MobileBERT. Empirical studies show that
MobileBERT is 4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive results on well-known
benchmarks. On the natural language inference tasks of GLUE, MobileBERT achieves a GLUE score of 77.7
(0.6 lower than BERT_BASE), and 62 ms latency on a Pixel 4 phone. On the SQuAD v1.1/v2.0 question answering task,
MobileBERT achieves a dev F1 score of 90.0/79.2 (1.5/2.1 higher than BERT_BASE).*

Tips:

- MobileBERT is a model with absolute position embeddings, so it's usually advised to pad the inputs on
  the right rather than the left.
- MobileBERT is similar to BERT and therefore relies on the masked language modeling (MLM) objective.
  It is therefore efficient at predicting masked tokens (see the short sketch below) and at NLU in general,
  but is not optimal for text generation. Models trained with a causal language modeling (CLM) objective
  are better in that regard.

The original code can be found `here <https://github.com/google-research/mobilebert>`_.
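A minimal masked-language-modeling sketch (it assumes the ``google/mobilebert-uncased`` checkpoint referenced in this
pull request; the exact prediction depends on the released weights):

.. code-block:: python

    import torch
    from transformers import MobileBertTokenizer, MobileBertForMaskedLM

    tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
    model = MobileBertForMaskedLM.from_pretrained("google/mobilebert-uncased")

    # Encode a sentence containing a [MASK] token
    input_ids = tokenizer.encode(
        f"HuggingFace is creating a {tokenizer.mask_token} that the community uses.",
        return_tensors="pt",
    )
    with torch.no_grad():
        prediction_scores = model(input_ids)[0]  # (batch_size, sequence_length, vocab_size)

    # Take the highest-scoring token at the masked position
    mask_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    predicted_id = prediction_scores[0, mask_index].argmax(-1).item()
    print(tokenizer.decode([predicted_id]))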
MobileBertConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertConfig
    :members:


MobileBertTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertTokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


MobileBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertTokenizerFast
    :members:


MobileBertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertModel
    :members:


MobileBertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForPreTraining
    :members:


MobileBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForMaskedLM
    :members:


MobileBertForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForNextSentencePrediction
    :members:


MobileBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForSequenceClassification
    :members:


MobileBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForMultipleChoice
    :members:


MobileBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForTokenClassification
    :members:


MobileBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForQuestionAnswering
    :members:


TFMobileBertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertModel
    :members:


TFMobileBertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForPreTraining
    :members:


TFMobileBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForMaskedLM
    :members:


TFMobileBertForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForNextSentencePrediction
    :members:


TFMobileBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForSequenceClassification
    :members:


TFMobileBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForMultipleChoice
    :members:


TFMobileBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForTokenClassification
    :members:


TFMobileBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForQuestionAnswering
    :members:
32 model_cards/google/mobilebert-uncased/README.md (new file)
@@ -0,0 +1,32 @@
---
language: english
thumbnail: https://huggingface.co/front/thumbnails/google.png
license: apache-2.0
---

## MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices

MobileBERT is a thin version of BERT_LARGE, while equipped with bottleneck structures and a carefully designed balance
between self-attentions and feed-forward networks.

This checkpoint is the original MobileBert Optimized Uncased English:
[uncased_L-24_H-128_B-512_A-4_F-4_OPT](https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz)
checkpoint.

## How to use MobileBERT in `transformers`

```python
from transformers import pipeline

fill_mask = pipeline(
    "fill-mask",
    model="google/mobilebert-uncased",
    tokenizer="google/mobilebert-uncased"
)

print(
    fill_mask(f"HuggingFace is creating a {fill_mask.tokenizer.mask_token} that the community uses to solve NLP tasks.")
)
```
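The same checkpoint can also be used outside the pipeline API, for example to extract hidden states with the bare
model (a minimal sketch):

```python
import torch
from transformers import MobileBertModel, MobileBertTokenizer

tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertModel.from_pretrained("google/mobilebert-uncased")

input_ids = tokenizer.encode("Hello, MobileBERT!", return_tensors="pt")
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]  # (1, sequence_length, 512)
print(last_hidden_state.shape)
```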
@ -34,6 +34,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
@ -129,6 +130,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
@ -188,6 +190,21 @@ if is_torch_available():
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
)
|
||||
|
||||
from .modeling_mobilebert import (
|
||||
MobileBertPreTrainedModel,
|
||||
MobileBertModel,
|
||||
MobileBertForPreTraining,
|
||||
MobileBertForSequenceClassification,
|
||||
MobileBertForQuestionAnswering,
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForNextSentencePrediction,
|
||||
MobileBertForMultipleChoice,
|
||||
MobileBertForTokenClassification,
|
||||
load_tf_weights_in_mobilebert,
|
||||
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MobileBertLayer,
|
||||
)
|
||||
|
||||
from .modeling_bert import (
|
||||
BertPreTrainedModel,
|
||||
BertModel,
|
||||
@ -495,6 +512,20 @@ if is_tf_available():
|
||||
TFGPT2PreTrainedModel,
|
||||
)
|
||||
|
||||
from .modeling_tf_mobilebert import (
|
||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFMobileBertModel,
|
||||
TFMobileBertPreTrainedModel,
|
||||
TFMobileBertForPreTraining,
|
||||
TFMobileBertForSequenceClassification,
|
||||
TFMobileBertForQuestionAnswering,
|
||||
TFMobileBertForMaskedLM,
|
||||
TFMobileBertForNextSentencePrediction,
|
||||
TFMobileBertForMultipleChoice,
|
||||
TFMobileBertForTokenClassification,
|
||||
TFMobileBertMainLayer,
|
||||
)
|
||||
|
||||
from .modeling_tf_openai import (
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFOpenAIGPTDoubleHeadsModel,
|
||||
|
@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mobilebert import MobileBertConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_reformer import ReformerConfig
|
||||
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
@ -75,6 +76,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
[
|
||||
("retribert", RetriBertConfig,),
|
||||
("t5", T5Config,),
|
||||
("mobilebert", MobileBertConfig,),
|
||||
("distilbert", DistilBertConfig,),
|
||||
("albert", AlbertConfig,),
|
||||
("camembert", CamembertConfig,),
|
||||
|
159 src/transformers/configuration_mobilebert.py (new file)
@@ -0,0 +1,159 @@
|
||||
# coding=utf-8
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MobileBERT model configuration """
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json"
|
||||
}
|
||||
|
||||
|
||||
class MobileBertConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.MobileBertModel`.
|
||||
It is used to instantiate a MobileBERT model according to the specified arguments, defining the model
|
||||
architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
|
||||
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, optional, defaults to 30522):
|
||||
Vocabulary size of the MobileBERT model. Defines the number of different tokens that
|
||||
can be represented by the :obj:`input_ids` passed to the forward method of :class:`~transformers.MobileBertModel`.
|
||||
hidden_size (:obj:`int`, optional, defaults to 512):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, optional, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, optional, defaults to 4):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, optional, defaults to 512):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
||||
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, optional, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, optional, defaults to 2):
|
||||
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.MobileBertModel`.
|
||||
initializer_range (:obj:`float`, optional, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
|
||||
pad_token_id (:obj:`int`, optional, defaults to 0):
|
||||
The ID of the token in the word embedding to use as padding.
|
||||
embedding_size (:obj:`int`, optional, defaults to 128):
|
||||
The dimension of the word embedding vectors.
|
||||
trigram_input (:obj:`bool`, optional, defaults to True):
|
||||
Whether to use a convolution of trigrams as input.
|
||||
use_bottleneck (:obj:`bool`, optional, defaults to True):
|
||||
Whether to use bottleneck in BERT.
|
||||
intra_bottleneck_size (:obj:`int`, optional, defaults to 128):
|
||||
Size of bottleneck layer output.
|
||||
use_bottleneck_attention (:obj:`bool`, optional, defaults to False):
|
||||
Whether to use attention inputs from the bottleneck transformation.
|
||||
key_query_shared_bottleneck (:obj:`bool`, optional, defaults to True):
|
||||
Whether to use the same linear transformation for query and key in the bottleneck.
|
||||
num_feedforward_networks (:obj:`int`, optional, defaults to 4):
|
||||
Number of FFNs in a block.
|
||||
normalization_type (:obj:`str`, optional, defaults to "no_norm"):
|
||||
The normalization type in BERT.
|
||||
|
||||
Example:
|
||||
|
||||
from transformers import MobileBertModel, MobileBertConfig
|
||||
|
||||
# Initializing a MobileBERT configuration
|
||||
configuration = MobileBertConfig()
|
||||
|
||||
# Initializing a model from the configuration above
|
||||
model = MobileBertModel(configuration)
|
||||
|
||||
# Accessing the model configuration
|
||||
configuration = model.config
|
||||
|
||||
Attributes:
|
||||
pretrained_config_archive_map (Dict[str, str]):
|
||||
A dictionary containing all the available pre-trained checkpoints.
|
||||
"""
|
||||
pretrained_config_archive_map = MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "mobilebert"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=512,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=512,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
embedding_size=128,
|
||||
trigram_input=True,
|
||||
use_bottleneck=True,
|
||||
intra_bottleneck_size=128,
|
||||
use_bottleneck_attention=False,
|
||||
key_query_shared_bottleneck=True,
|
||||
num_feedforward_networks=4,
|
||||
normalization_type="no_norm",
|
||||
classifier_activation=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.embedding_size = embedding_size
|
||||
self.trigram_input = trigram_input
|
||||
self.use_bottleneck = use_bottleneck
|
||||
self.intra_bottleneck_size = intra_bottleneck_size
|
||||
self.use_bottleneck_attention = use_bottleneck_attention
|
||||
self.key_query_shared_bottleneck = key_query_shared_bottleneck
|
||||
self.num_feedforward_networks = num_feedforward_networks
|
||||
self.normalization_type = normalization_type
|
||||
self.classifier_activation = classifier_activation
|
||||
|
||||
if self.use_bottleneck:
|
||||
self.true_hidden_size = intra_bottleneck_size
|
||||
else:
|
||||
self.true_hidden_size = hidden_size
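As a small illustration of the bottleneck switch above (a sketch; the values are arbitrary, not a recommended
configuration):

```python
from transformers import MobileBertConfig

# With the bottleneck enabled, the transformer blocks run at intra_bottleneck_size ...
config = MobileBertConfig(use_bottleneck=True, intra_bottleneck_size=128, hidden_size=512)
assert config.true_hidden_size == 128

# ... without it, they run at the full hidden_size.
config = MobileBertConfig(use_bottleneck=False, hidden_size=512)
assert config.true_hidden_size == 512
```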
|
@ -0,0 +1,42 @@
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = MobileBertConfig.from_json_file(mobilebert_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = MobileBertForPreTraining(config)
|
||||
# Load weights from tf checkpoint
|
||||
model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mobilebert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained MobileBERT model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path)
|
@ -32,6 +32,7 @@ from .configuration_auto import (
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
@ -111,6 +112,15 @@ from .modeling_longformer import (
|
||||
LongformerModel,
|
||||
)
|
||||
from .modeling_marian import MarianMTModel
|
||||
from .modeling_mobilebert import (
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForMultipleChoice,
|
||||
MobileBertForPreTraining,
|
||||
MobileBertForQuestionAnswering,
|
||||
MobileBertForSequenceClassification,
|
||||
MobileBertForTokenClassification,
|
||||
MobileBertModel,
|
||||
)
|
||||
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
|
||||
from .modeling_retribert import RetriBertModel
|
||||
@ -166,6 +176,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(BertConfig, BertModel),
|
||||
(OpenAIGPTConfig, OpenAIGPTModel),
|
||||
(GPT2Config, GPT2Model),
|
||||
(MobileBertConfig, MobileBertModel),
|
||||
(TransfoXLConfig, TransfoXLModel),
|
||||
(XLNetConfig, XLNetModel),
|
||||
(FlaubertConfig, FlaubertModel),
|
||||
@ -190,6 +201,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(BertConfig, BertForPreTraining),
|
||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||
(GPT2Config, GPT2LMHeadModel),
|
||||
(MobileBertConfig, MobileBertForPreTraining),
|
||||
(TransfoXLConfig, TransfoXLLMHeadModel),
|
||||
(XLNetConfig, XLNetLMHeadModel),
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
@ -213,6 +225,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(BertConfig, BertForMaskedLM),
|
||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||
(GPT2Config, GPT2LMHeadModel),
|
||||
(MobileBertConfig, MobileBertForMaskedLM),
|
||||
(TransfoXLConfig, TransfoXLLMHeadModel),
|
||||
(XLNetConfig, XLNetLMHeadModel),
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
@ -249,6 +262,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
(LongformerConfig, LongformerForMaskedLM),
|
||||
(RobertaConfig, RobertaForMaskedLM),
|
||||
(BertConfig, BertForMaskedLM),
|
||||
(MobileBertConfig, MobileBertForMaskedLM),
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
(XLMConfig, XLMWithLMHeadModel),
|
||||
(ElectraConfig, ElectraForMaskedLM),
|
||||
@ -275,6 +289,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(RobertaConfig, RobertaForSequenceClassification),
|
||||
(BertConfig, BertForSequenceClassification),
|
||||
(XLNetConfig, XLNetForSequenceClassification),
|
||||
(MobileBertConfig, MobileBertForSequenceClassification),
|
||||
(FlaubertConfig, FlaubertForSequenceClassification),
|
||||
(XLMConfig, XLMForSequenceClassification),
|
||||
(ElectraConfig, ElectraForSequenceClassification),
|
||||
@ -292,6 +307,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
(BertConfig, BertForQuestionAnswering),
|
||||
(XLNetConfig, XLNetForQuestionAnsweringSimple),
|
||||
(FlaubertConfig, FlaubertForQuestionAnsweringSimple),
|
||||
(MobileBertConfig, MobileBertForQuestionAnswering),
|
||||
(XLMConfig, XLMForQuestionAnsweringSimple),
|
||||
(ElectraConfig, ElectraForQuestionAnswering),
|
||||
]
|
||||
@ -306,6 +322,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(LongformerConfig, LongformerForTokenClassification),
|
||||
(RobertaConfig, RobertaForTokenClassification),
|
||||
(BertConfig, BertForTokenClassification),
|
||||
(MobileBertConfig, MobileBertForTokenClassification),
|
||||
(XLNetConfig, XLNetForTokenClassification),
|
||||
(AlbertConfig, AlbertForTokenClassification),
|
||||
(ElectraConfig, ElectraForTokenClassification),
|
||||
@ -322,6 +339,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
(RobertaConfig, RobertaForMultipleChoice),
|
||||
(BertConfig, BertForMultipleChoice),
|
||||
(DistilBertConfig, DistilBertForMultipleChoice),
|
||||
(MobileBertConfig, MobileBertForMultipleChoice),
|
||||
(XLNetConfig, XLNetForMultipleChoice),
|
||||
(AlbertConfig, AlbertForMultipleChoice),
|
||||
]
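With these mappings in place, MobileBERT checkpoints resolve through the auto classes; a quick sketch (assuming the
`google/mobilebert-uncased` checkpoint referenced elsewhere in this PR):

```python
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering

config = AutoConfig.from_pretrained("google/mobilebert-uncased")
print(type(config).__name__)  # MobileBertConfig

model = AutoModel.from_pretrained("google/mobilebert-uncased")
print(type(model).__name__)  # MobileBertModel

# Task-specific auto classes dispatch through the new mappings as well.
qa_model = AutoModelForQuestionAnswering.from_pretrained("google/mobilebert-uncased")
print(type(qa_model).__name__)  # MobileBertForQuestionAnswering
```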
|
||||
|
1614 src/transformers/modeling_mobilebert.py (new file)
File diff suppressed because it is too large
@ -28,6 +28,7 @@ from .configuration_auto import (
|
||||
ElectraConfig,
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
MobileBertConfig,
|
||||
OpenAIGPTConfig,
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
@ -88,6 +89,15 @@ from .modeling_tf_flaubert import (
|
||||
TFFlaubertWithLMHeadModel,
|
||||
)
|
||||
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
|
||||
from .modeling_tf_mobilebert import (
|
||||
TFMobileBertForMaskedLM,
|
||||
TFMobileBertForMultipleChoice,
|
||||
TFMobileBertForPreTraining,
|
||||
TFMobileBertForQuestionAnswering,
|
||||
TFMobileBertForSequenceClassification,
|
||||
TFMobileBertForTokenClassification,
|
||||
TFMobileBertModel,
|
||||
)
|
||||
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
|
||||
from .modeling_tf_roberta import (
|
||||
TFRobertaForMaskedLM,
|
||||
@ -138,6 +148,7 @@ TF_MODEL_MAPPING = OrderedDict(
|
||||
(ElectraConfig, TFElectraModel),
|
||||
(FlaubertConfig, TFFlaubertModel),
|
||||
(GPT2Config, TFGPT2Model),
|
||||
(MobileBertConfig, TFMobileBertModel),
|
||||
(OpenAIGPTConfig, TFOpenAIGPTModel),
|
||||
(RobertaConfig, TFRobertaModel),
|
||||
(T5Config, TFT5Model),
|
||||
@ -158,6 +169,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(ElectraConfig, TFElectraForPreTraining),
|
||||
(FlaubertConfig, TFFlaubertWithLMHeadModel),
|
||||
(GPT2Config, TFGPT2LMHeadModel),
|
||||
(MobileBertConfig, TFMobileBertForPreTraining),
|
||||
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||
(RobertaConfig, TFRobertaForMaskedLM),
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
@ -178,6 +190,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(ElectraConfig, TFElectraForMaskedLM),
|
||||
(FlaubertConfig, TFFlaubertWithLMHeadModel),
|
||||
(GPT2Config, TFGPT2LMHeadModel),
|
||||
(MobileBertConfig, TFMobileBertForMaskedLM),
|
||||
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||
(RobertaConfig, TFRobertaForMaskedLM),
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
@ -195,6 +208,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
(CamembertConfig, TFCamembertForMultipleChoice),
|
||||
(DistilBertConfig, TFDistilBertForMultipleChoice),
|
||||
(FlaubertConfig, TFFlaubertForMultipleChoice),
|
||||
(MobileBertConfig, TFMobileBertForMultipleChoice),
|
||||
(RobertaConfig, TFRobertaForMultipleChoice),
|
||||
(XLMConfig, TFXLMForMultipleChoice),
|
||||
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
|
||||
@ -210,6 +224,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
||||
(ElectraConfig, TFElectraForQuestionAnswering),
|
||||
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
|
||||
(MobileBertConfig, TFMobileBertForQuestionAnswering),
|
||||
(RobertaConfig, TFRobertaForQuestionAnswering),
|
||||
(XLMConfig, TFXLMForQuestionAnsweringSimple),
|
||||
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
|
||||
@ -224,6 +239,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(CamembertConfig, TFCamembertForSequenceClassification),
|
||||
(DistilBertConfig, TFDistilBertForSequenceClassification),
|
||||
(FlaubertConfig, TFFlaubertForSequenceClassification),
|
||||
(MobileBertConfig, TFMobileBertForSequenceClassification),
|
||||
(RobertaConfig, TFRobertaForSequenceClassification),
|
||||
(XLMConfig, TFXLMForSequenceClassification),
|
||||
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
|
||||
@ -239,6 +255,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(DistilBertConfig, TFDistilBertForTokenClassification),
|
||||
(ElectraConfig, TFElectraForTokenClassification),
|
||||
(FlaubertConfig, TFFlaubertForTokenClassification),
|
||||
(MobileBertConfig, TFMobileBertForTokenClassification),
|
||||
(RobertaConfig, TFRobertaForTokenClassification),
|
||||
(XLMConfig, TFXLMForTokenClassification),
|
||||
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
|
||||
|
1474 src/transformers/modeling_tf_mobilebert.py (new file)
File diff suppressed because it is too large
@ -18,6 +18,8 @@
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from transformers.configuration_mobilebert import MobileBertConfig
|
||||
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
@ -55,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_longformer import LongformerTokenizer
|
||||
from .tokenization_marian import MarianTokenizer
|
||||
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||
@ -73,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
[
|
||||
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||
(T5Config, (T5Tokenizer, None)),
|
||||
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||
(AlbertConfig, (AlbertTokenizer, None)),
|
||||
(CamembertConfig, (CamembertTokenizer, None)),
|
||||
|
69 src/transformers/tokenization_mobilebert.py (new file)
@@ -0,0 +1,69 @@
|
||||
# coding=utf-8
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for MobileBERT."""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/vocab.txt"
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
||||
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {}
|
||||
|
||||
|
||||
class MobileBertTokenizer(BertTokenizer):
|
||||
r"""
|
||||
Constructs a MobileBertTokenizer.
|
||||
|
||||
:class:`~transformers.MobileBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
class MobileBertTokenizerFast(BertTokenizerFast):
|
||||
r"""
|
||||
Constructs a "Fast" MobileBertTokenizer (backed by HuggingFace's `tokenizers` library).
|
||||
|
||||
:class:`~transformers.MobileBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
499 tests/test_modeling_mobilebert.py (new file)
@@ -0,0 +1,499 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from .utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from transformers import (
|
||||
MobileBertConfig,
|
||||
MobileBertModel,
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForNextSentencePrediction,
|
||||
MobileBertForPreTraining,
|
||||
MobileBertForQuestionAnswering,
|
||||
MobileBertForSequenceClassification,
|
||||
MobileBertForTokenClassification,
|
||||
MobileBertForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_mobilebert import MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class MobileBertModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=64,
|
||||
embedding_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.embedding_size = embedding_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = MobileBertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
embedding_size=self.embedding_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_mobilebert_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MobileBertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_mobilebert_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = MobileBertModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
sequence_output, pooled_output = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_mobilebert_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MobileBertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_next_sequence_prediction(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MobileBertForNextSentencePrediction(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, seq_relationship_score = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MobileBertForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MobileBertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = MobileBertForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = MobileBertForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = MobileBertForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
MobileBertModel,
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForMultipleChoice,
|
||||
MobileBertForNextSentencePrediction,
|
||||
MobileBertForPreTraining,
|
||||
MobileBertForQuestionAnswering,
|
||||
MobileBertForSequenceClassification,
|
||||
MobileBertForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_mobilebert_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_model(*config_and_inputs)
|
||||
|
||||
def test_mobilebert_model_as_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_mobilebert_model_as_decoder(*config_and_inputs)
|
||||
|
||||
def test_mobilebert_model_as_decoder_with_default_input_mask(self):
|
||||
# This regression test was failing with PyTorch < 1.3
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
input_mask = None
|
||||
|
||||
self.model_tester.create_and_check_mobilebert_model_as_decoder(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_next_sequence_prediction(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs)
|
||||
|
||||
def test_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = MobileBertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
|
||||
|
||||
|
||||
TOLERANCE = 1e-3
|
||||
|
||||
|
||||
@require_torch
|
||||
class MobileBertModelIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = MobileBertModel.from_pretrained("google/mobilebert-uncased").to(torch_device)
|
||||
input_ids = _long_tensor([[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]])
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 9, 512))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-2.4736526e07, 8.2691656e04, 1.6521838e05],
|
||||
[-5.7541704e-01, 3.9056022e00, 4.4011507e00],
|
||||
[2.6047359e00, 1.5677652e00, -1.7324188e-01],
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# MobileBERT results range from 10e0 to 10e8. Even a 0.0000001% difference with a value of 10e8 results in a
|
||||
# ~1 difference, it's therefore not a good idea to measure using addition.
|
||||
# Here, we instead divide the expected result with the result in order to obtain ~1. We then check that the
|
||||
# result is held between bounds: 1 - TOLERANCE < expected_result / result < 1 + TOLERANCE
|
||||
lower_bound = torch.all((expected_slice / output[..., :3, :3]) >= 1 - TOLERANCE)
|
||||
upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
|
||||
|
||||
self.assertTrue(lower_bound and upper_bound)
|
321 tests/test_modeling_tf_mobilebert.py (new file)
@@ -0,0 +1,321 @@
|
||||
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import MobileBertConfig, is_tf_available

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow


if is_tf_available():
    import tensorflow as tf
    from transformers.modeling_tf_mobilebert import (
        TFMobileBertModel,
        TFMobileBertForMaskedLM,
        TFMobileBertForNextSentencePrediction,
        TFMobileBertForPreTraining,
        TFMobileBertForSequenceClassification,
        TFMobileBertForMultipleChoice,
        TFMobileBertForTokenClassification,
        TFMobileBertForQuestionAnswering,
    )


@require_tf
class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            TFMobileBertModel,
            TFMobileBertForMaskedLM,
            TFMobileBertForNextSentencePrediction,
            TFMobileBertForPreTraining,
            TFMobileBertForQuestionAnswering,
            TFMobileBertForSequenceClassification,
            TFMobileBertForTokenClassification,
            TFMobileBertForMultipleChoice,
        )
        if is_tf_available()
        else ()
    )

    class TFMobileBertModelTester(object):
        def __init__(
            self,
            parent,
            batch_size=13,
            seq_length=7,
            is_training=True,
            use_input_mask=True,
            use_token_type_ids=True,
            use_labels=True,
            vocab_size=99,
            hidden_size=32,
            embedding_size=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=16,
            type_sequence_label_size=2,
            initializer_range=0.02,
            num_labels=3,
            num_choices=4,
            scope=None,
        ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.embedding_size = embedding_size

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = MobileBertConfig(
                vocab_size=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range,
                embedding_size=self.embedding_size,
            )

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def create_and_check_mobilebert_model(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = TFMobileBertModel(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            sequence_output, pooled_output = model(inputs)

            inputs = [input_ids, input_mask]
            sequence_output, pooled_output = model(inputs)

            sequence_output, pooled_output = model(input_ids)

            result = {
                "sequence_output": sequence_output.numpy(),
                "pooled_output": pooled_output.numpy(),
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
            )
            self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])

        def create_and_check_mobilebert_for_masked_lm(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = TFMobileBertForMaskedLM(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            (prediction_scores,) = model(inputs)
            result = {
                "prediction_scores": prediction_scores.numpy(),
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
            )

        def create_and_check_mobilebert_for_next_sequence_prediction(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = TFMobileBertForNextSentencePrediction(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            (seq_relationship_score,) = model(inputs)
            result = {
                "seq_relationship_score": seq_relationship_score.numpy(),
            }
            self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2])

        def create_and_check_mobilebert_for_pretraining(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = TFMobileBertForPreTraining(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            prediction_scores, seq_relationship_score = model(inputs)
            result = {
                "prediction_scores": prediction_scores.numpy(),
                "seq_relationship_score": seq_relationship_score.numpy(),
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
            )
            self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2])

        def create_and_check_mobilebert_for_sequence_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = TFMobileBertForSequenceClassification(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            (logits,) = model(inputs)
            result = {
                "logits": logits.numpy(),
            }
            self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])

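        # The multiple-choice head expects inputs of shape [batch_size, num_choices, seq_length],
        # so the flat [batch_size, seq_length] tensors are duplicated across a new choice axis
        # with tf.expand_dims + tf.tile before being fed to the model.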
        def create_and_check_mobilebert_for_multiple_choice(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_choices = self.num_choices
            model = TFMobileBertForMultipleChoice(config=config)
            multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
            multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
            multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
            inputs = {
                "input_ids": multiple_choice_inputs_ids,
                "attention_mask": multiple_choice_input_mask,
                "token_type_ids": multiple_choice_token_type_ids,
            }
            (logits,) = model(inputs)
            result = {
                "logits": logits.numpy(),
            }
            self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])

        def create_and_check_mobilebert_for_token_classification(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            config.num_labels = self.num_labels
            model = TFMobileBertForTokenClassification(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            (logits,) = model(inputs)
            result = {
                "logits": logits.numpy(),
            }
            self.parent.assertListEqual(
                list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
            )

        def create_and_check_mobilebert_for_question_answering(
            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
        ):
            model = TFMobileBertForQuestionAnswering(config=config)
            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
            start_logits, end_logits = model(inputs)
            result = {
                "start_logits": start_logits.numpy(),
                "end_logits": end_logits.numpy(),
            }
            self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
                config,
                input_ids,
                token_type_ids,
                input_mask,
                sequence_labels,
                token_labels,
                choice_labels,
            ) = config_and_inputs
            inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
            return config, inputs_dict

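    # The methods below belong to the outer TFMobileBertModelTest class; they drive the nested
    # TFMobileBertModelTester defined above through self.model_tester.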
    def setUp(self):
        self.model_tester = TFMobileBertModelTest.TFMobileBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_mobilebert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_model(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        # for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
        for model_name in ["mobilebert-uncased"]:
            model = TFMobileBertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
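A minimal sketch of what the slow checkpoint test above exercises end to end, assuming the "mobilebert-uncased" identifier used in the test resolves to a downloadable checkpoint and that TensorFlow is installed; the tuple return mirrors the unpacking used throughout the tests above:

import tensorflow as tf
from transformers import TFMobileBertModel

# checkpoint identifier taken from test_model_from_pretrained above
model = TFMobileBertModel.from_pretrained("mobilebert-uncased")

# toy input ids, shape [batch_size, seq_length]
input_ids = tf.constant([[101, 2023, 2003, 1037, 3231, 102]])
sequence_output, pooled_output = model({"input_ids": input_ids})

print(sequence_output.shape)  # (1, 6, hidden_size)
print(pooled_output.shape)    # (1, hidden_size)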