From 9a3f91088c87b8002df9496ae4300fc78289c983 Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Fri, 19 Jun 2020 23:38:36 +0300 Subject: [PATCH] Add MobileBert (#4901) * Add MobileBert * Quality + Conversion script * style * Update src/transformers/modeling_mobilebert.py * Links to S3 * Style * TFMobileBert Slight fixes to the pytorch MobileBert Style * MobileBertForMaskedLM (PT + TF) * MobileBertForNextSentencePrediction (PT + TF) * MobileFor{MultipleChoice, TokenClassification} (PT + TF) ss * Tests + Auto * Doc * Tests * Addressing @sgugger's comments * Adressing @patrickvonplaten's comments * Style * Style * Integration test * style * Model card Co-authored-by: Lysandre Co-authored-by: Lysandre Debut --- docs/source/index.rst | 1 + docs/source/model_doc/mobilebert.rst | 169 ++ .../google/mobilebert-uncased/README.md | 32 + src/transformers/__init__.py | 31 + src/transformers/configuration_auto.py | 2 + src/transformers/configuration_mobilebert.py | 159 ++ ...ebert_original_tf_checkpoint_to_pytorch.py | 42 + src/transformers/modeling_auto.py | 18 + src/transformers/modeling_mobilebert.py | 1614 +++++++++++++++++ src/transformers/modeling_tf_auto.py | 17 + src/transformers/modeling_tf_mobilebert.py | 1474 +++++++++++++++ src/transformers/tokenization_auto.py | 4 + src/transformers/tokenization_mobilebert.py | 69 + tests/test_modeling_mobilebert.py | 499 +++++ tests/test_modeling_tf_mobilebert.py | 321 ++++ 15 files changed, 4452 insertions(+) create mode 100644 docs/source/model_doc/mobilebert.rst create mode 100644 model_cards/google/mobilebert-uncased/README.md create mode 100644 src/transformers/configuration_mobilebert.py create mode 100644 src/transformers/convert_mobilebert_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/modeling_mobilebert.py create mode 100644 src/transformers/modeling_tf_mobilebert.py create mode 100644 src/transformers/tokenization_mobilebert.py create mode 100644 tests/test_modeling_mobilebert.py create mode 100644 tests/test_modeling_tf_mobilebert.py diff --git a/docs/source/index.rst b/docs/source/index.rst index e1f5902861e..b84276ec059 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -184,3 +184,4 @@ conversion utilities for the following models: model_doc/marian model_doc/longformer model_doc/retribert + model_doc/mobilebert diff --git a/docs/source/model_doc/mobilebert.rst b/docs/source/model_doc/mobilebert.rst new file mode 100644 index 00000000000..b8f0bba1445 --- /dev/null +++ b/docs/source/model_doc/mobilebert.rst @@ -0,0 +1,169 @@ +MobileBERT +---------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~ + +The MobileBERT model was proposed in `MobileBERT: a Compact Task-Agnostic BERT +for Resource-Limited Devices `__ +by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. It's a bidirectional transformer +based on the BERT model, which is compressed and accelerated using several approaches. + +The abstract from the paper is the following: + +*Natural Language Processing (NLP) has recently achieved great success by using huge pre-trained models with hundreds +of millions of parameters. However, these models suffer from heavy model sizes and high latency such that they cannot +be deployed to resource-limited mobile devices. In this paper, we propose MobileBERT for compressing and accelerating +the popular BERT model. 
Like the original BERT, MobileBERT is task-agnostic, that is, it can be generically applied
+to various downstream NLP tasks via simple fine-tuning. Basically, MobileBERT is a thin version of BERT_LARGE, while
+equipped with bottleneck structures and a carefully designed balance between self-attentions and feed-forward
+networks. To train MobileBERT, we first train a specially designed teacher model, an inverted-bottleneck incorporated
+BERT_LARGE model. Then, we conduct knowledge transfer from this teacher to MobileBERT. Empirical studies show that
+MobileBERT is 4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive results on well-known
+benchmarks. On the natural language inference tasks of GLUE, MobileBERT achieves a GLUEscore of 77.7
+(0.6 lower than BERT_BASE), and 62 ms latency on a Pixel 4 phone. On the SQuAD v1.1/v2.0 question answering task,
+MobileBERT achieves a dev F1 score of 90.0/79.2 (1.5/2.1 higher than BERT_BASE).*
+
+Tips:
+
+- MobileBERT is a model with absolute position embeddings, so it's usually advised to pad the inputs on
+  the right rather than the left.
+- MobileBERT is similar to BERT and therefore relies on the masked language modeling (MLM) objective.
+  It is therefore efficient at predicting masked tokens and at NLU in general, but is not optimal for
+  text generation. Models trained with a causal language modeling (CLM) objective are better in that regard.
+
+The original code can be found `here <https://github.com/google-research/google-research/tree/master/mobilebert>`_.
+
+MobileBertConfig
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertConfig
+    :members:
+
+
+MobileBertTokenizer
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertTokenizer
+    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
+        create_token_type_ids_from_sequences, save_vocabulary
+
+
+MobileBertTokenizerFast
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertTokenizerFast
+    :members:
+
+
+MobileBertModel
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertModel
+    :members:
+
+
+MobileBertForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForPreTraining
+    :members:
+
+
+MobileBertForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForMaskedLM
+    :members:
+
+
+MobileBertForNextSentencePrediction
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForNextSentencePrediction
+    :members:
+
+
+MobileBertForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForSequenceClassification
+    :members:
+
+
+MobileBertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForMultipleChoice
+    :members:
+
+
+MobileBertForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForTokenClassification
+    :members:
+
+
+MobileBertForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MobileBertForQuestionAnswering
+    :members:
+
+
+TFMobileBertModel
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFMobileBertModel
+    :members:
+
+
+TFMobileBertForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFMobileBertForPreTraining
+    :members:
+
+
+TFMobileBertForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFMobileBertForMaskedLM
+    :members:
+
+
+TFMobileBertForNextSentencePrediction
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. 
autoclass:: transformers.TFMobileBertForNextSentencePrediction + :members: + + +TFMobileBertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMobileBertForSequenceClassification + :members: + + +TFMobileBertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMobileBertForMultipleChoice + :members: + + +TFMobileBertForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMobileBertForTokenClassification + :members: + + +TFMobileBertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMobileBertForQuestionAnswering + :members: + diff --git a/model_cards/google/mobilebert-uncased/README.md b/model_cards/google/mobilebert-uncased/README.md new file mode 100644 index 00000000000..b36556015c4 --- /dev/null +++ b/model_cards/google/mobilebert-uncased/README.md @@ -0,0 +1,32 @@ +--- +language: english +thumbnail: https://huggingface.co/front/thumbnails/google.png + +license: apache-2.0 +--- + +## MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices + +MobileBERT is a thin version of BERT_LARGE, while equipped with bottleneck structures and a carefully designed balance +between self-attentions and feed-forward networks. + +This checkpoint is the original MobileBert Optimized Uncased English: +[uncased_L-24_H-128_B-512_A-4_F-4_OPT](https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz) +checkpoint. + +## How to use MobileBERT in `transformers` + +```python +from transformers import pipeline + +fill_mask = pipeline( + "fill-mask", + model="google/mobilebert-uncased", + tokenizer="google/mobilebert-uncased" +) + +print( + fill_mask(f"HuggingFace is creating a {fill_mask.tokenizer.mask_token} that the community uses to solve NLP tasks.") +) + +``` diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 55db51d53bb..dfe12b8bd9c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -34,6 +34,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig from .configuration_mmbt import MMBTConfig +from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig @@ -129,6 +130,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast +from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_reformer import ReformerTokenizer from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast @@ -188,6 +190,21 @@ if is_torch_available(): MODEL_FOR_MULTIPLE_CHOICE_MAPPING, ) + from .modeling_mobilebert import ( + MobileBertPreTrainedModel, + MobileBertModel, + MobileBertForPreTraining, + MobileBertForSequenceClassification, + 
MobileBertForQuestionAnswering, + MobileBertForMaskedLM, + MobileBertForNextSentencePrediction, + MobileBertForMultipleChoice, + MobileBertForTokenClassification, + load_tf_weights_in_mobilebert, + MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + MobileBertLayer, + ) + from .modeling_bert import ( BertPreTrainedModel, BertModel, @@ -495,6 +512,20 @@ if is_tf_available(): TFGPT2PreTrainedModel, ) + from .modeling_tf_mobilebert import ( + TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFMobileBertModel, + TFMobileBertPreTrainedModel, + TFMobileBertForPreTraining, + TFMobileBertForSequenceClassification, + TFMobileBertForQuestionAnswering, + TFMobileBertForMaskedLM, + TFMobileBertForNextSentencePrediction, + TFMobileBertForMultipleChoice, + TFMobileBertForTokenClassification, + TFMobileBertMainLayer, + ) + from .modeling_tf_openai import ( TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, TFOpenAIGPTDoubleHeadsModel, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 09a76abe261..64d2324bce2 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig +from .configuration_mobilebert import MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_reformer import ReformerConfig from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig @@ -75,6 +76,7 @@ CONFIG_MAPPING = OrderedDict( [ ("retribert", RetriBertConfig,), ("t5", T5Config,), + ("mobilebert", MobileBertConfig,), ("distilbert", DistilBertConfig,), ("albert", AlbertConfig,), ("camembert", CamembertConfig,), diff --git a/src/transformers/configuration_mobilebert.py b/src/transformers/configuration_mobilebert.py new file mode 100644 index 00000000000..cfb16baf2c3 --- /dev/null +++ b/src/transformers/configuration_mobilebert.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MobileBERT model configuration """ + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json" +} + + +class MobileBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.MobileBertModel`. + It is used to instantiate a MobileBERT model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. 
Read the documentation from :class:`~transformers.PretrainedConfig`
+        for more information.
+
+
+        Args:
+            vocab_size (:obj:`int`, optional, defaults to 30522):
+                Vocabulary size of the MobileBERT model. Defines the different tokens that
+                can be represented by the `input_ids` passed to the forward method of :class:`~transformers.MobileBertModel`.
+            hidden_size (:obj:`int`, optional, defaults to 512):
+                Dimensionality of the encoder layers and the pooler layer.
+            num_hidden_layers (:obj:`int`, optional, defaults to 24):
+                Number of hidden layers in the Transformer encoder.
+            num_attention_heads (:obj:`int`, optional, defaults to 4):
+                Number of attention heads for each attention layer in the Transformer encoder.
+            intermediate_size (:obj:`int`, optional, defaults to 512):
+                Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+            hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
+                The non-linear activation function (function or string) in the encoder and pooler.
+                If string, "gelu", "relu", "swish" and "gelu_new" are supported.
+            hidden_dropout_prob (:obj:`float`, optional, defaults to 0.0):
+                The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+            attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
+                The dropout ratio for the attention probabilities.
+            max_position_embeddings (:obj:`int`, optional, defaults to 512):
+                The maximum sequence length that this model might ever be used with.
+                Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+            type_vocab_size (:obj:`int`, optional, defaults to 2):
+                The vocabulary size of the `token_type_ids` passed into :class:`~transformers.MobileBertModel`.
+            initializer_range (:obj:`float`, optional, defaults to 0.02):
+                The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+            layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
+                The epsilon used by the layer normalization layers.
+            pad_token_id (:obj:`int`, optional, defaults to 0):
+                The ID of the token in the word embedding to use as padding.
+            embedding_size (:obj:`int`, optional, defaults to 128):
+                The dimension of the word embedding vectors.
+            trigram_input (:obj:`bool`, optional, defaults to :obj:`True`):
+                Whether to use a convolution of trigrams as input, as described in the MobileBERT paper.
+            use_bottleneck (:obj:`bool`, optional, defaults to :obj:`True`):
+                Whether to use bottleneck layers in the transformer blocks.
+            intra_bottleneck_size (:obj:`int`, optional, defaults to 128):
+                Size of the bottleneck layer outputs.
+            use_bottleneck_attention (:obj:`bool`, optional, defaults to :obj:`False`):
+                Whether to use attention inputs from the bottleneck transformation.
+            key_query_shared_bottleneck (:obj:`bool`, optional, defaults to :obj:`True`):
+                Whether to use the same linear transformation for the query and key in the bottleneck.
+            num_feedforward_networks (:obj:`int`, optional, defaults to 4):
+                Number of FFNs in a block.
+            normalization_type (:obj:`str`, optional, defaults to "no_norm"):
+                The normalization type used in the model, either "no_norm" or "layer_norm".
+            classifier_activation (:obj:`bool`, optional, defaults to :obj:`True`):
+                Whether to apply a tanh activation to the pooled (first-token) output, as in BERT's pooler.
+
+        Example::
+
+            from transformers import MobileBertModel, MobileBertConfig
+
+            # Initializing a MobileBERT configuration
+            configuration = MobileBertConfig()
+
+            # Initializing a model from the configuration above
+            model = MobileBertModel(configuration)
+
+            # Accessing the model configuration
+            configuration = model.config
+
+        Attributes:
+            pretrained_config_archive_map (Dict[str, str]):
+                A dictionary containing all the available pre-trained checkpoints.
+ """ + pretrained_config_archive_map = MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP + model_type = "mobilebert" + + def __init__( + self, + vocab_size=30522, + hidden_size=512, + num_hidden_layers=24, + num_attention_heads=4, + intermediate_size=512, + hidden_act="relu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + embedding_size=128, + trigram_input=True, + use_bottleneck=True, + intra_bottleneck_size=128, + use_bottleneck_attention=False, + key_query_shared_bottleneck=True, + num_feedforward_networks=4, + normalization_type="no_norm", + classifier_activation=True, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.embedding_size = embedding_size + self.trigram_input = trigram_input + self.use_bottleneck = use_bottleneck + self.intra_bottleneck_size = intra_bottleneck_size + self.use_bottleneck_attention = use_bottleneck_attention + self.key_query_shared_bottleneck = key_query_shared_bottleneck + self.num_feedforward_networks = num_feedforward_networks + self.normalization_type = normalization_type + self.classifier_activation = classifier_activation + + if self.use_bottleneck: + self.true_hidden_size = intra_bottleneck_size + else: + self.true_hidden_size = hidden_size diff --git a/src/transformers/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/convert_mobilebert_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..9651069baaf --- /dev/null +++ b/src/transformers/convert_mobilebert_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,42 @@ +import argparse +import logging + +import torch + +from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert + + +logging.basicConfig(level=logging.INFO) + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): + # Initialise PyTorch model + config = MobileBertConfig.from_json_file(mobilebert_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = MobileBertForPreTraining(config) + # Load weights from tf checkpoint + model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path) + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--mobilebert_config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained MobileBERT model. 
\n" + "This specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 26d0f7d5d72..e06b6b8a077 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -32,6 +32,7 @@ from .configuration_auto import ( FlaubertConfig, GPT2Config, LongformerConfig, + MobileBertConfig, OpenAIGPTConfig, ReformerConfig, RetriBertConfig, @@ -111,6 +112,15 @@ from .modeling_longformer import ( LongformerModel, ) from .modeling_marian import MarianMTModel +from .modeling_mobilebert import ( + MobileBertForMaskedLM, + MobileBertForMultipleChoice, + MobileBertForPreTraining, + MobileBertForQuestionAnswering, + MobileBertForSequenceClassification, + MobileBertForTokenClassification, + MobileBertModel, +) from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel from .modeling_reformer import ReformerModel, ReformerModelWithLMHead from .modeling_retribert import RetriBertModel @@ -166,6 +176,7 @@ MODEL_MAPPING = OrderedDict( (BertConfig, BertModel), (OpenAIGPTConfig, OpenAIGPTModel), (GPT2Config, GPT2Model), + (MobileBertConfig, MobileBertModel), (TransfoXLConfig, TransfoXLModel), (XLNetConfig, XLNetModel), (FlaubertConfig, FlaubertModel), @@ -190,6 +201,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (BertConfig, BertForPreTraining), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), + (MobileBertConfig, MobileBertForPreTraining), (TransfoXLConfig, TransfoXLLMHeadModel), (XLNetConfig, XLNetLMHeadModel), (FlaubertConfig, FlaubertWithLMHeadModel), @@ -213,6 +225,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( (BertConfig, BertForMaskedLM), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), + (MobileBertConfig, MobileBertForMaskedLM), (TransfoXLConfig, TransfoXLLMHeadModel), (XLNetConfig, XLNetLMHeadModel), (FlaubertConfig, FlaubertWithLMHeadModel), @@ -249,6 +262,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( (LongformerConfig, LongformerForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (BertConfig, BertForMaskedLM), + (MobileBertConfig, MobileBertForMaskedLM), (FlaubertConfig, FlaubertWithLMHeadModel), (XLMConfig, XLMWithLMHeadModel), (ElectraConfig, ElectraForMaskedLM), @@ -275,6 +289,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (RobertaConfig, RobertaForSequenceClassification), (BertConfig, BertForSequenceClassification), (XLNetConfig, XLNetForSequenceClassification), + (MobileBertConfig, MobileBertForSequenceClassification), (FlaubertConfig, FlaubertForSequenceClassification), (XLMConfig, XLMForSequenceClassification), (ElectraConfig, ElectraForSequenceClassification), @@ -292,6 +307,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( (BertConfig, BertForQuestionAnswering), (XLNetConfig, XLNetForQuestionAnsweringSimple), (FlaubertConfig, FlaubertForQuestionAnsweringSimple), + (MobileBertConfig, MobileBertForQuestionAnswering), (XLMConfig, XLMForQuestionAnsweringSimple), (ElectraConfig, ElectraForQuestionAnswering), ] @@ -306,6 +322,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( (LongformerConfig, LongformerForTokenClassification), (RobertaConfig, RobertaForTokenClassification), (BertConfig, BertForTokenClassification), + (MobileBertConfig, 
MobileBertForTokenClassification), (XLNetConfig, XLNetForTokenClassification), (AlbertConfig, AlbertForTokenClassification), (ElectraConfig, ElectraForTokenClassification), @@ -322,6 +339,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( (RobertaConfig, RobertaForMultipleChoice), (BertConfig, BertForMultipleChoice), (DistilBertConfig, DistilBertForMultipleChoice), + (MobileBertConfig, MobileBertForMultipleChoice), (XLNetConfig, XLNetForMultipleChoice), (AlbertConfig, AlbertForMultipleChoice), ] diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py new file mode 100644 index 00000000000..48e7f875dda --- /dev/null +++ b/src/transformers/modeling_mobilebert.py @@ -0,0 +1,1614 @@ +# MIT License +# +# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import logging +import math +import os +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from transformers.modeling_bert import BertIntermediate + +from .activations import gelu, gelu_new, swish +from .configuration_mobilebert import MobileBertConfig +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer + + +logger = logging.getLogger(__name__) +MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["mobilebert-uncased"] + + +def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model. + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+        )
+        raise
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        logger.info("Loading TF weight {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        arrays.append(array)
+
+    for name, array in zip(names, arrays):
+        name = name.replace("ffn_layer", "ffn")
+        name = name.replace("FakeLayerNorm", "LayerNorm")
+        name = name.replace("extra_output_weights", "dense/kernel")
+        name = name.replace("bert", "mobilebert")
+        name = name.split("/")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
+        # which are not required for using the pretrained model
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
+            logger.info("Skipping {}".format("/".join(name)))
+            continue
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+                scope_names = re.split(r"_(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info("Skipping {}".format("/".join(name)))
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if m_name[-11:] == "_embeddings":
+            pointer = getattr(pointer, "weight")
+        elif m_name == "kernel":
+            array = np.transpose(array)
+        try:
+            assert pointer.shape == array.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        logger.info("Initialize PyTorch weight {}".format(name))
+        pointer.data = torch.from_numpy(array)
+    return model
+
+
+def mish(x):
+    return x * torch.tanh(nn.functional.softplus(x))
+
+
+class NoNorm(nn.Module):
+    def __init__(self, feat_size, eps=None):
+        super().__init__()
+        self.bias = nn.Parameter(torch.zeros(feat_size))
+        self.weight = nn.Parameter(torch.ones(feat_size))
+
+    def forward(self, input_tensor):
+        return input_tensor * self.weight + self.bias
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
+NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
+
+
+class MobileBertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings.
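+
+    In MobileBERT the word embeddings are ``embedding_size``-dimensional (128 by default) rather than
+    ``hidden_size``-dimensional; they are widened to ``hidden_size`` below through an optional trigram
+    concatenation followed by the linear ``embedding_transformation``.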
+ """ + + def __init__(self, config): + super().__init__() + self.trigram_input = config.trigram_input + self.embedding_size = config.embedding_size + self.hidden_size = config.hidden_size + + self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + embed_dim_multiplier = 3 if self.trigram_input else 1 + embedded_input_size = self.embedding_size * embed_dim_multiplier + self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size) + + self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.trigram_input: + # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited + # Devices (https://arxiv.org/abs/2004.02984) + # + # The embedding table in BERT models accounts for a substantial proportion of model size. To compress + # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT. + # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512 + # dimensional output. + inputs_embeds = torch.cat( + [ + F.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0), + inputs_embeds, + F.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0), + ], + dim=2, + ) + if self.trigram_input or self.embedding_size != self.hidden_size: + inputs_embeds = self.embedding_transformation(inputs_embeds) + + # Add positional embeddings and token type embeddings, then layer + # normalize and perform dropout. 
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class MobileBertSelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.true_hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.true_hidden_size, self.all_head_size)
+        self.value = nn.Linear(
+            config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size
+        )
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        query_tensor,
+        key_tensor,
+        value_tensor,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+    ):
+        mixed_query_layer = self.query(query_tensor)
+        mixed_key_layer = self.key(key_tensor)
+        mixed_value_layer = self.value(value_tensor)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the MobileBertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs) + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class MobileBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.use_bottleneck = config.use_bottleneck + self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size) + self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps) + if not self.use_bottleneck: + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, residual_tensor): + layer_outputs = self.dense(hidden_states) + if not self.use_bottleneck: + layer_outputs = self.dropout(layer_outputs) + layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) + return layer_outputs + + +class MobileBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = MobileBertSelfAttention(config) + self.output = MobileBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + query_tensor, + key_tensor, + value_tensor, + layer_input, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + ): + self_outputs = self.self( + query_tensor, + key_tensor, + value_tensor, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + # Run a linear projection of `hidden_size` then add a residual + # with `layer_input`. 
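+        # (With bottlenecks enabled this projection is actually `true_hidden_size`-dimensional, i.e. 128
+        # with the default config; see MobileBertSelfOutput.)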
+        attention_output = self.output(self_outputs[0], layer_input)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class MobileBertIntermediate(BertIntermediate):
+    def __init__(self, config):
+        super().__init__(config)
+        self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
+
+
+class OutputBottleneck(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
+        self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, residual_tensor):
+        layer_outputs = self.dense(hidden_states)
+        layer_outputs = self.dropout(layer_outputs)
+        layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
+        return layer_outputs
+
+
+class MobileBertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.use_bottleneck = config.use_bottleneck
+        self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
+        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
+        if not self.use_bottleneck:
+            self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        else:
+            self.bottleneck = OutputBottleneck(config)
+
+    def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2):
+        layer_output = self.dense(intermediate_states)
+        if not self.use_bottleneck:
+            layer_output = self.dropout(layer_output)
+            layer_output = self.LayerNorm(layer_output + residual_tensor_1)
+        else:
+            layer_output = self.LayerNorm(layer_output + residual_tensor_1)
+            layer_output = self.bottleneck(layer_output, residual_tensor_2)
+        return layer_output
+
+
+class BottleneckLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
+        self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        layer_input = self.dense(hidden_states)
+        layer_input = self.LayerNorm(layer_input)
+        return layer_input
+
+
+class Bottleneck(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
+        self.use_bottleneck_attention = config.use_bottleneck_attention
+        self.input = BottleneckLayer(config)
+        if self.key_query_shared_bottleneck:
+            self.attention = BottleneckLayer(config)
+
+    def forward(self, hidden_states):
+        # This method can return three different tuples of values. These different values make use of bottlenecks,
+        # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
+        # usage. These linear layers have weights that are learned during training.
+        #
+        # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the
+        # key, query, value, and "layer input" to be used by the attention layer.
+        # This bottleneck is used to project the hidden states. This last layer input will be used as a residual tensor
+        # in the attention self output, after the attention scores have been computed.
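+        # (In this mode the query, key, value and residual all live in the bottleneck, so with the default
+        # config the whole attention computation runs at 128 dimensions instead of 512.)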
+ # + # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return + # four values, three of which have been passed through a bottleneck: the query and key, passed through the same + # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck. + # + # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck, + # and the residual layer will be this value passed through a bottleneck. + + bottlenecked_hidden_states = self.input(hidden_states) + if self.use_bottleneck_attention: + return (bottlenecked_hidden_states,) * 4 + elif self.key_query_shared_bottleneck: + shared_attention_input = self.attention(hidden_states) + return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states) + else: + return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states) + + +class FFNOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size) + self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, residual_tensor): + layer_outputs = self.dense(hidden_states) + layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) + return layer_outputs + + +class FFNLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate = MobileBertIntermediate(config) + self.output = FFNOutput(config) + + def forward(self, hidden_states): + intermediate_output = self.intermediate(hidden_states) + layer_outputs = self.output(intermediate_output, hidden_states) + return layer_outputs + + +class MobileBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.use_bottleneck = config.use_bottleneck + self.num_feedforward_networks = config.num_feedforward_networks + + self.attention = MobileBertAttention(config) + self.intermediate = MobileBertIntermediate(config) + self.output = MobileBertOutput(config) + if self.use_bottleneck: + self.bottleneck = Bottleneck(config) + if config.num_feedforward_networks > 1: + self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + ): + if self.use_bottleneck: + query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) + else: + query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4 + + self_attention_outputs = self.attention( + query_tensor, + key_tensor, + value_tensor, + layer_input, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + s = (attention_output,) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.num_feedforward_networks != 1: + for i, ffn_module in enumerate(self.ffn): + attention_output = ffn_module(attention_output) + s += (attention_output,) + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, hidden_states) + outputs = ( + (layer_output,) + + outputs + + ( + torch.tensor(1000), + query_tensor, + key_tensor, + value_tensor, + layer_input, + attention_output, + intermediate_output, + ) + + s + ) + return outputs + + 
+class MobileBertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.output_hidden_states = config.output_hidden_states
+        self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=False,
+    ):
+        all_hidden_states = ()
+        all_attentions = ()
+        for i, layer_module in enumerate(self.layer):
+            if self.output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                head_mask[i],
+                encoder_hidden_states,
+                encoder_attention_mask,
+                output_attentions,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        outputs = (hidden_states,)
+        if self.output_hidden_states:
+            outputs = outputs + (all_hidden_states,)
+        if output_attentions:
+            outputs = outputs + (all_attentions,)
+        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+
+
+class MobileBertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.do_activate = config.classifier_activation
+        if self.do_activate:
+            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        if not self.do_activate:
+            return first_token_tensor
+        else:
+            pooled_output = self.dense(first_token_tensor)
+            pooled_output = torch.tanh(pooled_output)
+            return pooled_output
+
+
+class MobileBertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class MobileBertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = MobileBertPredictionHeadTransform(config)
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
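+        # Since hidden_size (512) is larger than embedding_size (128), only the first embedding_size rows
+        # of the (hidden_size, vocab_size) output projection can be tied to the word embeddings; `dense`
+        # below holds the remaining (hidden_size - embedding_size) rows, and forward() concatenates the two
+        # pieces before applying them.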
+ self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False) + self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)) + hidden_states += self.bias + return hidden_states + + +class MobileBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MobileBertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class MobileBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = MobileBertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class MobileBertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = MobileBertConfig + pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST + load_tf_weights = load_tf_weights_in_mobilebert + base_model_prefix = "mobilebert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, NoNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +MOBILEBERT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.MobileBertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +MOBILEBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.MobileBertTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.encode_plus` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. 
+ Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. +""" + + +@add_start_docstrings( + "The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.", + MOBILEBERT_START_DOCSTRING, +) +class MobileBertModel(MobileBertPreTrainedModel): + """ + https://arxiv.org/pdf/2004.02984.pdf + """ + + def __init__(self, config): + super().__init__(config) + self.config = config + self.embeddings = MobileBertEmbeddings(config) + self.encoder = MobileBertEncoder(config) + self.pooler = MobileBertPooler(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
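+            For example, ``{0: [0, 1]}`` prunes heads 0 and 1 of layer 0.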
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+            See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+    ):
+        r"""
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token)
+            further processed by a Linear layer and a Tanh activation function. The Linear
+            layer weights are trained from the next sentence prediction (classification)
+            objective during pre-training.
+
+            This output is usually *not* a good summary of the semantic content of the input;
+            you're often better off averaging or pooling the sequence of hidden-states for
+            the whole input sequence.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+        Examples::
+
+            from transformers import MobileBertModel, MobileBertTokenizer
+            import torch
+
+            tokenizer = MobileBertTokenizer.from_pretrained(model_name_or_path)
+            model = MobileBertModel.from_pretrained(model_name_or_path)
+
+            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+            outputs = model(input_ids)
+
+            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+
+        """
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(input_shape, device=device)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape, self.device
+        )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output)
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[
+            1:
+        ]  # add hidden_states and attentions if they are here
+        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
+    a `next sentence prediction (classification)` head. 
""", + MOBILEBERT_START_DOCSTRING, +) +class MobileBertForPreTraining(MobileBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.mobilebert = MobileBertModel(config) + self.cls = MobileBertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning + the weights instead. + """ + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + resized_dense = nn.Linear( + input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False + ) + kept_data = self.cls.predictions.dense.weight.data[ + ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1]) + ] + resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data + self.cls.predictions.dense = resized_dense + + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + ): + r""" + labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False + continuation before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import MobileBertTokenizer, MobileBertForPreTraining
+        import torch
+
+        tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased')
+        model = MobileBertForPreTraining.from_pretrained('mobilebert-uncased')
+
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+
+        prediction_scores, seq_relationship_scores = outputs[:2]
+
+        """
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+        sequence_output, pooled_output = outputs[:2]
+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[
+            2:
+        ]  # add hidden states and attention if they are here
+
+        if labels is not None and next_sentence_label is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+            total_loss = masked_lm_loss + next_sentence_loss
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
+class MobileBertForMaskedLM(MobileBertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.mobilebert = MobileBertModel(config)
+        self.cls = MobileBertOnlyMLMHead(config)
+        self.config = config
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def tie_weights(self):
+        """
+        Tie the weights between the input embeddings and the output embeddings.
+        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
+        the weights instead.
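+
+        Note (added for clarity): MobileBert's LM head also carries an extra ``dense`` projection
+        whose weight has a vocabulary-sized dimension, so it is rebuilt below to track the
+        (possibly resized) input embeddings before the decoder weights are tied.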
+ """ + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + resized_dense = nn.Linear( + input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False + ) + kept_data = self.cls.predictions.dense.weight.data[ + ..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1]) + ] + resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data + self.cls.predictions.dense = resized_dense + + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ + Examples:: + + from transformers import MobileBertTokenizer, MobileBertForMaskedLM + import torch + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = MobileBertForMaskedLM.from_pretrained('mobilebert-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=input_ids) + + loss, prediction_scores = outputs[:2] + + """ + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("masked_lm_labels") + + outputs = self.mobilebert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + + +class MobileBertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +@add_start_docstrings( + """MobileBert Model with a `next sentence prediction (classification)` head on top. """, + MOBILEBERT_START_DOCSTRING, +) +class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.mobilebert = MobileBertModel(config) + self.cls = MobileBertOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + output_attentions=None, + ): + r""" + next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) + Indices should be in ``[0, 1]``. + ``0`` indicates sequence B is a continuation of sequence A, + ``1`` indicates sequence B is a random sequence. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided): + Next sequence prediction (classification) loss. + seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). 
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import MobileBertTokenizer, MobileBertForNextSentencePrediction
+        import torch
+
+        tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased')
+        model = MobileBertForNextSentencePrediction.from_pretrained('mobilebert-uncased')
+
+        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+        encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='pt')
+
+        loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
+        assert logits[0, 0] < logits[0, 1]  # next sentence was random
+        """
+
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+
+        pooled_output = outputs[1]
+
+        seq_relationship_score = self.cls(pooled_output)
+
+        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
+        if next_sentence_label is not None:
+            loss_fct = CrossEntropyLoss()
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+            outputs = (next_sentence_loss,) + outputs
+
+        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks. """,
+    MOBILEBERT_START_DOCSTRING,
+)
+class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.mobilebert = MobileBertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
+            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
+        import torch
+
+        tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased')
+        model = MobileBertForSequenceClassification.from_pretrained('mobilebert-uncased')
+
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+
+        loss, logits = outputs[:2]
+        """
+
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            outputs = (loss,) + outputs
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""", + MOBILEBERT_START_DOCSTRING, +) +class MobileBertForQuestionAnswering(MobileBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mobilebert = MobileBertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+
+    Examples::
+
+        from transformers import MobileBertTokenizer, MobileBertForQuestionAnswering
+        import torch
+
+        tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased')
+        model = MobileBertForQuestionAnswering.from_pretrained('mobilebert-uncased')
+
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        encoding = tokenizer.encode_plus(question, text)
+        input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
+        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
+
+        assert answer == "a nice puppet"
+
+        """
+
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    MOBILEBERT_START_DOCSTRING,
+)
+class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.mobilebert = MobileBertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Classification loss.
+        classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
+
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import MobileBertTokenizer, MobileBertForMultipleChoice
+        import torch
+
+        tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased')
+        model = MobileBertForMultipleChoice.from_pretrained('mobilebert-uncased')
+
+        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        choice0 = "It is eaten with a fork and a knife."
+        choice1 = "It is eaten while held in the hand."
+        labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1
+
+        encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
+        outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels)  # batch size is 1
+
+        # the linear classifier still needs to be trained
+        loss, logits = outputs[:2]
+        """
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
""", + MOBILEBERT_START_DOCSTRING, +) +class MobileBertForTokenClassification(MobileBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mobilebert = MobileBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the token classification loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : + Classification loss. + scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`) + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ + Examples:: + + from transformers import MobileBertTokenizer, MobileBertForTokenClassification + import torch + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = MobileBertForTokenClassification.from_pretrained('mobilebert-uncased') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + + loss, scores = outputs[:2] + + """ + + outputs = self.mobilebert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), scores, (hidden_states), (attentions) diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index 23521424fe7..2b0cf73f74c 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -28,6 +28,7 @@ from .configuration_auto import ( ElectraConfig, FlaubertConfig, GPT2Config, + MobileBertConfig, OpenAIGPTConfig, RobertaConfig, T5Config, @@ -88,6 +89,15 @@ from .modeling_tf_flaubert import ( TFFlaubertWithLMHeadModel, ) from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model +from .modeling_tf_mobilebert import ( + TFMobileBertForMaskedLM, + TFMobileBertForMultipleChoice, + TFMobileBertForPreTraining, + TFMobileBertForQuestionAnswering, + TFMobileBertForSequenceClassification, + TFMobileBertForTokenClassification, + TFMobileBertModel, +) from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel from .modeling_tf_roberta import ( TFRobertaForMaskedLM, @@ -138,6 +148,7 @@ TF_MODEL_MAPPING = OrderedDict( (ElectraConfig, TFElectraModel), (FlaubertConfig, TFFlaubertModel), (GPT2Config, TFGPT2Model), + (MobileBertConfig, TFMobileBertModel), (OpenAIGPTConfig, TFOpenAIGPTModel), (RobertaConfig, TFRobertaModel), (T5Config, TFT5Model), @@ -158,6 +169,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (ElectraConfig, TFElectraForPreTraining), (FlaubertConfig, TFFlaubertWithLMHeadModel), (GPT2Config, TFGPT2LMHeadModel), + (MobileBertConfig, TFMobileBertForPreTraining), (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), (RobertaConfig, TFRobertaForMaskedLM), (T5Config, TFT5ForConditionalGeneration), @@ -178,6 +190,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( (ElectraConfig, TFElectraForMaskedLM), (FlaubertConfig, TFFlaubertWithLMHeadModel), (GPT2Config, TFGPT2LMHeadModel), + (MobileBertConfig, TFMobileBertForMaskedLM), (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), (RobertaConfig, TFRobertaForMaskedLM), (T5Config, TFT5ForConditionalGeneration), @@ -195,6 +208,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( 
(CamembertConfig, TFCamembertForMultipleChoice), (DistilBertConfig, TFDistilBertForMultipleChoice), (FlaubertConfig, TFFlaubertForMultipleChoice), + (MobileBertConfig, TFMobileBertForMultipleChoice), (RobertaConfig, TFRobertaForMultipleChoice), (XLMConfig, TFXLMForMultipleChoice), (XLMRobertaConfig, TFXLMRobertaForMultipleChoice), @@ -210,6 +224,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( (DistilBertConfig, TFDistilBertForQuestionAnswering), (ElectraConfig, TFElectraForQuestionAnswering), (FlaubertConfig, TFFlaubertForQuestionAnsweringSimple), + (MobileBertConfig, TFMobileBertForQuestionAnswering), (RobertaConfig, TFRobertaForQuestionAnswering), (XLMConfig, TFXLMForQuestionAnsweringSimple), (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering), @@ -224,6 +239,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (CamembertConfig, TFCamembertForSequenceClassification), (DistilBertConfig, TFDistilBertForSequenceClassification), (FlaubertConfig, TFFlaubertForSequenceClassification), + (MobileBertConfig, TFMobileBertForSequenceClassification), (RobertaConfig, TFRobertaForSequenceClassification), (XLMConfig, TFXLMForSequenceClassification), (XLMRobertaConfig, TFXLMRobertaForSequenceClassification), @@ -239,6 +255,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( (DistilBertConfig, TFDistilBertForTokenClassification), (ElectraConfig, TFElectraForTokenClassification), (FlaubertConfig, TFFlaubertForTokenClassification), + (MobileBertConfig, TFMobileBertForTokenClassification), (RobertaConfig, TFRobertaForTokenClassification), (XLMConfig, TFXLMForTokenClassification), (XLMRobertaConfig, TFXLMRobertaForTokenClassification), diff --git a/src/transformers/modeling_tf_mobilebert.py b/src/transformers/modeling_tf_mobilebert.py new file mode 100644 index 00000000000..e0e2b1fd343 --- /dev/null +++ b/src/transformers/modeling_tf_mobilebert.py @@ -0,0 +1,1474 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 MobileBERT model. """ + + +import logging + +import tensorflow as tf + +from . 
import MobileBertConfig +from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_tf_bert import TFBertIntermediate, gelu, gelu_new, swish +from .modeling_tf_utils import ( + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + cast_bool_to_primitive, + get_initializer, + keras_serializable, + shape_list, +) +from .tokenization_utils import BatchEncoding + + +logger = logging.getLogger(__name__) + + +TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "mobilebert-uncased", + # See all MobileBERT models at https://huggingface.co/models?filter=mobilebert +] + + +def mish(x): + return x * tf.tanh(tf.math.softplus(x)) + + +class TFLayerNorm(tf.keras.layers.LayerNormalization): + def __init__(self, feat_size, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class TFNoNorm(tf.keras.layers.Layer): + def __init__(self, feat_size, epsilon=None, **kwargs): + super().__init__(**kwargs) + self.feat_size = feat_size + + def build(self, input_shape): + self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros") + self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones") + + def call(self, inputs: tf.Tensor): + return inputs * self.weight + self.bias + + +ACT2FN = { + "gelu": tf.keras.layers.Activation(gelu), + "relu": tf.keras.activations.relu, + "swish": tf.keras.layers.Activation(swish), + "gelu_new": tf.keras.layers.Activation(gelu_new), +} +NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm} + + +class TFMobileBertEmbeddings(tf.keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.trigram_input = config.trigram_input + self.embedding_size = config.embedding_size + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.initializer_range = config.initializer_range + + self.position_embeddings = tf.keras.layers.Embedding( + config.max_position_embeddings, + config.hidden_size, + embeddings_initializer=get_initializer(self.initializer_range), + name="position_embeddings", + ) + self.token_type_embeddings = tf.keras.layers.Embedding( + config.type_vocab_size, + config.hidden_size, + embeddings_initializer=get_initializer(self.initializer_range), + name="token_type_embeddings", + ) + + self.embedding_transformation = tf.keras.layers.Dense(config.hidden_size, name="embedding_transformation") + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = NORM2FN[config.normalization_type]( + config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" + ) + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def build(self, input_shape): + """Build shared word embedding layer """ + with tf.name_scope("word_embeddings"): + # Create and initialize weights. The random normal initializer was chosen + # arbitrarily, and works well. + self.word_embeddings = self.add_weight( + "weight", + shape=[self.vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + super().build(input_shape) + + def call(self, inputs, mode="embedding", training=False): + """Get token embeddings of inputs. 
+ Args: + inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids) + mode: string, a valid value is one of "embedding" and "linear". + Returns: + outputs: (1) If mode == "embedding", output embedding tensor, float32 with + shape [batch_size, length, embedding_size]; (2) mode == "linear", output + linear tensor, float32 with shape [batch_size, length, vocab_size]. + Raises: + ValueError: if mode is not valid. + + Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + if mode == "embedding": + return self._embedding(inputs, training=training) + elif mode == "linear": + return self._linear(inputs) + else: + raise ValueError("mode {} is not valid.".format(mode)) + + def _embedding(self, inputs, training=False): + """Applies embedding based on inputs tensor.""" + input_ids, position_ids, token_type_ids, inputs_embeds = inputs + + if input_ids is not None: + input_shape = shape_list(input_ids) + else: + input_shape = shape_list(inputs_embeds)[:-1] + + seq_length = input_shape[1] + if position_ids is None: + position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :] + if token_type_ids is None: + token_type_ids = tf.fill(input_shape, 0) + + if inputs_embeds is None: + inputs_embeds = tf.gather(self.word_embeddings, input_ids) + + if self.trigram_input: + # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited + # Devices (https://arxiv.org/abs/2004.02984) + # + # The embedding table in BERT models accounts for a substantial proportion of model size. To compress + # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT. + # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512 + # dimensional output. + inputs_embeds = tf.concat( + [ + tf.pad(inputs_embeds[:, 1:], ((0, 0), (0, 1), (0, 0))), + inputs_embeds, + tf.pad(inputs_embeds[:, :-1], ((0, 0), (1, 0), (0, 0))), + ], + axis=2, + ) + + if self.trigram_input or self.embedding_size != self.hidden_size: + inputs_embeds = self.embedding_transformation(inputs_embeds) + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, training=training) + return embeddings + + def _linear(self, inputs): + """Computes logits by running inputs through a linear layer. + Args: + inputs: A float32 tensor with shape [batch_size, length, hidden_size] + Returns: + float32 tensor with shape [batch_size, length, vocab_size]. 
+ """ + batch_size = shape_list(inputs)[0] + length = shape_list(inputs)[1] + + x = tf.reshape(inputs, [-1, self.hidden_size]) + logits = tf.matmul(x, self.word_embeddings, transpose_b=True) + + return tf.reshape(logits, [batch_size, length, self.vocab_size]) + + +class TFMobileBertSelfAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + assert config.hidden_size % config.num_attention_heads == 0 + self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = tf.keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = tf.keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = tf.keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, batch_size): + x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call(self, inputs, training=False): + query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions = inputs + + batch_size = shape_list(attention_mask)[0] + mixed_query_layer = self.query(query_tensor) + mixed_key_layer = self.key(key_tensor) + mixed_value_layer = self.value(value_tensor) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = tf.matmul( + query_layer, key_layer, transpose_b=True + ) # (batch size, num_heads, seq_len_q, seq_len_k) + dk = tf.cast(shape_list(key_layer)[-1], tf.float32) # scale attention_scores + attention_scores = attention_scores / tf.math.sqrt(dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = tf.nn.softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = tf.matmul(attention_probs, value_layer) + + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + context_layer = tf.reshape( + context_layer, (batch_size, -1, self.all_head_size) + ) # (batch_size, seq_len_q, all_head_size) + + outputs = ( + (context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,) + ) + + return outputs + + +class TFMobileBertSelfOutput(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.use_bottleneck = config.use_bottleneck + self.dense = tf.keras.layers.Dense( + config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = NORM2FN[config.normalization_type]( + config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" + ) + if not self.use_bottleneck: + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, inputs, training=False): + hidden_states, residual_tensor = inputs + hidden_states = self.dense(hidden_states) + if not self.use_bottleneck: + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.LayerNorm(hidden_states + residual_tensor) + return hidden_states + + +class TFMobileBertAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.self = TFMobileBertSelfAttention(config, name="self") + self.mobilebert_output = TFMobileBertSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, inputs, training=False): + query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions = inputs + + self_outputs = self.self( + [query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions], training=training + ) + attention_output = self.mobilebert_output([self_outputs[0], layer_input], training=training) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class TFMobileBertIntermediate(TFBertIntermediate): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense") + + +class TFOutputBottleneck(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense") + self.LayerNorm = NORM2FN[config.normalization_type]( + config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" + ) + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, inputs, training=False): + hidden_states, residual_tensor = inputs + layer_outputs = self.dense(hidden_states) + layer_outputs = self.dropout(layer_outputs, training=training) + layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) + return layer_outputs + + +class TFMobileBertOutput(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.use_bottleneck = config.use_bottleneck + self.dense = tf.keras.layers.Dense( + config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = NORM2FN[config.normalization_type]( + config.true_hidden_size, 
epsilon=config.layer_norm_eps, name="LayerNorm"
+        )
+        if not self.use_bottleneck:
+            self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        else:
+            self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
+
+    def call(self, inputs, training=False):
+        hidden_states, residual_tensor_1, residual_tensor_2 = inputs
+
+        hidden_states = self.dense(hidden_states)
+        if not self.use_bottleneck:
+            hidden_states = self.dropout(hidden_states, training=training)
+            hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
+        else:
+            hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
+            hidden_states = self.bottleneck([hidden_states, residual_tensor_2])
+        return hidden_states
+
+
+class TFBottleneckLayer(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = tf.keras.layers.Dense(config.intra_bottleneck_size, name="dense")
+        self.LayerNorm = NORM2FN[config.normalization_type](
+            config.intra_bottleneck_size, epsilon=config.layer_norm_eps, name="LayerNorm"
+        )
+
+    def call(self, inputs):
+        hidden_states = self.dense(inputs)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class TFBottleneck(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
+        self.use_bottleneck_attention = config.use_bottleneck_attention
+        self.bottleneck_input = TFBottleneckLayer(config, name="input")
+        if self.key_query_shared_bottleneck:
+            self.attention = TFBottleneckLayer(config, name="attention")
+
+    def call(self, hidden_states):
+        # This method can return three different tuples of values. These different values make use of bottlenecks,
+        # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
+        # usage. These linear layers have weights that are learned during training.
+        #
+        # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the
+        # key, query, value, and "layer input" to be used by the attention layer.
+        # This last "layer input" will be used as a residual tensor
+        # in the attention self output, after the attention scores have been computed.
+        #
+        # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return
+        # four values, three of which have been passed through a bottleneck: the query and key, passed through the same
+        # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck.
+        #
+        # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck,
+        # and the residual layer will be this value passed through a bottleneck.
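+        #
+        # As an illustration (shapes assume the released mobilebert-uncased configuration,
+        # i.e. hidden_size=512 and intra_bottleneck_size=128; other configurations differ):
+        #   - use_bottleneck_attention: query, key, value and layer input are all the same
+        #     bottlenecked tensor of shape (batch, seq, 128)
+        #   - key_query_shared_bottleneck: query and key share one (batch, seq, 128) projection,
+        #     value stays at (batch, seq, 512), and the layer input is a separate (batch, seq, 128) projection
+        #   - otherwise: query, key and value stay at (batch, seq, 512) and only the
+        #     layer input is bottlenecked to (batch, seq, 128)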
+ + bottlenecked_hidden_states = self.bottleneck_input(hidden_states) + if self.use_bottleneck_attention: + return (bottlenecked_hidden_states,) * 4 + elif self.key_query_shared_bottleneck: + shared_attention_input = self.attention(hidden_states) + return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states) + else: + return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states) + + +class TFFFNOutput(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense(config.true_hidden_size, name="dense") + self.LayerNorm = NORM2FN[config.normalization_type]( + config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm" + ) + + def call(self, hidden_states, residual_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + residual_tensor) + return hidden_states + + +class TFFFNLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.intermediate = TFMobileBertIntermediate(config, name="intermediate") + self.mobilebert_output = TFFFNOutput(config, name="output") + + def call(self, hidden_states): + intermediate_output = self.intermediate(hidden_states) + layer_outputs = self.mobilebert_output(intermediate_output, hidden_states) + return layer_outputs + + +class TFMobileBertLayer(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.use_bottleneck = config.use_bottleneck + self.num_feedforward_networks = config.num_feedforward_networks + + self.attention = TFMobileBertAttention(config, name="attention") + self.intermediate = TFMobileBertIntermediate(config, name="intermediate") + self.mobilebert_output = TFMobileBertOutput(config, name="output") + + if self.use_bottleneck: + self.bottleneck = TFBottleneck(config, name="bottleneck") + if config.num_feedforward_networks > 1: + self.ffn = [ + TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1) + ] + + def call(self, inputs, training=False): + hidden_states, attention_mask, head_mask, output_attentions = inputs + + if self.use_bottleneck: + query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) + else: + query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4 + + attention_outputs = self.attention( + [query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions], + training=training, + ) + + attention_output = attention_outputs[0] + s = (attention_output,) + + if self.num_feedforward_networks != 1: + for i, ffn_module in enumerate(self.ffn): + attention_output = ffn_module(attention_output) + s += (attention_output,) + + intermediate_output = self.intermediate(attention_output) + layer_output = self.mobilebert_output( + [intermediate_output, attention_output, hidden_states], training=training + ) + outputs = ( + (layer_output,) + + attention_outputs[1:] + + (0, query_tensor, key_tensor, value_tensor, layer_input, attention_output, intermediate_output) + + s + ) # add attentions if we output them + return outputs + + +class TFMobileBertEncoder(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.output_hidden_states = config.output_hidden_states + self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] + + def call(self, inputs, training=False): + 
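+        # inputs is packed as [hidden_states, attention_mask, head_mask, output_attentions];
+        # each layer consumes the previous layer's hidden states, and the intermediate
+        # hidden states / attention weights are optionally collected along the way.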
hidden_states, attention_mask, head_mask, output_attentions = inputs + + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + [hidden_states, attention_mask, head_mask[i], output_attentions], training=training + ) + hidden_states = layer_outputs[0] + + if cast_bool_to_primitive(output_attentions) is True: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if cast_bool_to_primitive(output_attentions) is True: + outputs = outputs + (all_attentions,) + return outputs # outputs, (hidden states), (attentions) + + +class TFMobileBertPooler(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.do_activate = config.classifier_activation + if self.do_activate: + self.dense = tf.keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + + def call(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + if not self.do_activate: + return first_token_tensor + else: + pooled_output = self.dense(first_token_tensor) + return pooled_output + + +class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm") + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class TFMobileBertLMPredictionHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.transform = TFMobileBertPredictionHeadTransform(config, name="transform") + self.vocab_size = config.vocab_size + self.config = config + + def build(self, input_shape): + self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") + self.dense = self.add_weight( + shape=(self.config.hidden_size - self.config.embedding_size, self.vocab_size), + initializer="zeros", + trainable=True, + name="dense/weight", + ) + self.decoder = self.add_weight( + shape=(self.config.vocab_size, self.config.embedding_size), + initializer="zeros", + trainable=True, + name="decoder/weight", + ) + super().build(input_shape) + + def call(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0)) + hidden_states = hidden_states + self.bias + return hidden_states + + +class TFMobileBertMLMHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.predictions = TFMobileBertLMPredictionHead(config, 
name="predictions") + + def call(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class TFMobileBertPreTrainingHeads(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.predictions = TFMobileBertLMPredictionHead(config, name="predictions") + self.seq_relationship = tf.keras.layers.Dense(2, name="seq_relationship") + + def call(self, inputs): + sequence_output, pooled_output = inputs + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@keras_serializable +class TFMobileBertMainLayer(tf.keras.layers.Layer): + config_class = MobileBertConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.num_hidden_layers = config.num_hidden_layers + self.output_attentions = config.output_attentions + + self.embeddings = TFMobileBertEmbeddings(config, name="embeddings") + self.encoder = TFMobileBertEncoder(config, name="encoder") + self.pooler = TFMobileBertPooler(config, name="pooler") + + def get_input_embeddings(self): + return self.embeddings + + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + raise NotImplementedError + + def call( + self, + inputs, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + training=False, + ): + if isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else attention_mask + token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids + position_ids = inputs[3] if len(inputs) > 3 else position_ids + head_mask = inputs[4] if len(inputs) > 4 else head_mask + inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds + output_attentions = inputs[6] if len(inputs) > 6 else output_attentions + assert len(inputs) <= 7, "Too many inputs." + elif isinstance(inputs, (dict, BatchEncoding)): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", attention_mask) + token_type_ids = inputs.get("token_type_ids", token_type_ids) + position_ids = inputs.get("position_ids", position_ids) + head_mask = inputs.get("head_mask", head_mask) + inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) + output_attentions = inputs.get("output_attentions", output_attentions) + assert len(inputs) <= 7, "Too many inputs." + else: + input_ids = inputs + + output_attentions = output_attentions if output_attentions is not None else self.output_attentions + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) + if token_type_ids is None: + token_type_ids = tf.fill(input_shape, 0) + + # We create a 3D attention mask from a 2D tensor mask. 
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is simpler than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+
+        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.num_hidden_layers
+            # head_mask = tf.constant([0] * self.num_hidden_layers)
+
+        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
+        encoder_outputs = self.encoder(
+            [embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
+        )
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output)
+
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[
+            1:
+        ]  # add hidden_states and attentions if they are here
+        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
+
+
+class TFMobileBertPreTrainedModel(TFPreTrainedModel):
+    """ An abstract class to handle weights initialization and
+        a simple interface for downloading and loading pretrained models.
+    """
+
+    config_class = MobileBertConfig
+    base_model_prefix = "mobilebert"
+
+
+MOBILEBERT_START_DOCSTRING = r"""
+    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
+    Use it as a regular TF 2.0 Keras Model and
+    refer to the TF 2.0 documentation for all matters related to general usage and behavior.
+
+    .. note::
+
+        TF 2.0 models accept two formats as inputs:
+
+        - having all inputs as keyword arguments (like PyTorch models), or
+        - having all inputs as a list, tuple or dict in the first positional argument.
+
+        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
+        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
+
+        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
+        in the first positional argument:
+
+        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
+        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
+        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
+
+    Parameters:
+        config (:class:`~transformers.MobileBertConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the configuration.
+            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
+"""
+
+MOBILEBERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`transformers.MobileBertTokenizer`.
+            See :func:`transformers.PreTrainedTokenizer.encode` and
+            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+            Segment token indices to indicate first and second portions of the inputs.
+            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
+            corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`__
+        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
+            Indices of positions of each input sequence token in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
+
+            `What are position IDs? <../glossary.html#position-ids>`__
+        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
+        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
+            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
+            (if set to :obj:`False`) for evaluation.
+""" + + +@add_start_docstrings( + "The bare MobileBert Model transformer outputing raw hidden-states without any specific head on top.", + MOBILEBERT_START_DOCSTRING, +) +class TFMobileBertModel(TFMobileBertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def call(self, inputs, **kwargs): + r""" + Returns: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during the original Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + + Examples:: + + import tensorflow as tf + from transformers import MobileBertTokenizer, TFMobileBertModel + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = TFMobileBertModel.from_pretrained('mobilebert-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + """ + outputs = self.mobilebert(inputs, **kwargs) + return outputs + + +@add_start_docstrings( + """MobileBert Model with two heads on top as done during the pre-training: + a `masked language modeling` head and a `next sentence prediction (classification)` head. 
""", + MOBILEBERT_START_DOCSTRING, +) +class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") + self.cls = TFMobileBertPreTrainingHeads(config, name="cls") + + def get_output_embeddings(self): + return self.mobilebert.embeddings + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def call(self, inputs, **kwargs): + r""" + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + import tensorflow as tf + from transformers import MobileBertTokenizer, TFMobileBertForPreTraining + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = TFMobileBertForPreTraining.from_pretrained('mobilebert-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 + outputs = model(input_ids) + prediction_scores, seq_relationship_scores = outputs[:2] + + """ + outputs = self.mobilebert(inputs, **kwargs) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls([sequence_output, pooled_output]) + outputs = (prediction_scores, seq_relationship_score,) + outputs[ + 2: + ] # add hidden states and attention if they are here + + return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. 
""", MOBILEBERT_START_DOCSTRING) +class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") + self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") + + def get_output_embeddings(self): + return self.mobilebert.embeddings + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def call(self, inputs, **kwargs): + r""" + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + import tensorflow as tf + from transformers import MobileBertTokenizer, TFMobileBertForMaskedLM + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = TFMobileBertForMaskedLM.from_pretrained('mobilebert-uncased') + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 + outputs = model(input_ids) + prediction_scores = outputs[0] + + """ + outputs = self.mobilebert(inputs, **kwargs) + + sequence_output = outputs[0] + prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False)) + + outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + return outputs # prediction_scores, (hidden_states), (attentions) + + +class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.seq_relationship = tf.keras.layers.Dense(2, name="seq_relationship") + + def call(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +@add_start_docstrings( + """MobileBert Model with a `next sentence prediction (classification)` head on top. 
""", + MOBILEBERT_START_DOCSTRING, +) +class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") + self.cls = TFMobileBertOnlyNSPHead(config, name="cls") + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def call(self, inputs, **kwargs): + r""" + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`) + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + import tensorflow as tf + from transformers import MobileBertTokenizer, TFMobileBertForNextSentencePrediction + + tokenizer = MobileBertTokenizer.from_pretrained('mobilebert-uncased') + model = TFMobileBertForNextSentencePrediction.from_pretrained('mobilebert-uncased') + + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + next_sentence = "The sky is blue due to the shorter wavelength of blue light." + encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='tf') + + logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0] + assert logits[0][0] < logits[0][1] # the next sentence was random + """ + outputs = self.mobilebert(inputs, **kwargs) + + pooled_output = outputs[1] + seq_relationship_score = self.cls(pooled_output) + + outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + + return outputs # seq_relationship_score, (hidden_states), (attentions) + + +@add_start_docstrings( + """MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. 
""", + MOBILEBERT_START_DOCSTRING, +) +class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + + @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) + def call( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + training=False, + ): + r""" + labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs: + logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import MobileBertTokenizer, TFMobileBertForSequenceClassification
+
+        tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
+        model = TFMobileBertForSequenceClassification.from_pretrained('google/mobilebert-uncased')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]
+
+        """
+
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            training=training,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output, training=training)
+        logits = self.classifier(pooled_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    MOBILEBERT_START_DOCSTRING,
+)
+class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
+        self.qa_outputs = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        output_attentions=None,
+        training=False,
+    ):
+        r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Returns:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import MobileBertTokenizer, TFMobileBertForQuestionAnswering
+
+        tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
+        model = TFMobileBertForQuestionAnswering.from_pretrained('google/mobilebert-uncased')  # Not a fine-tuned model! Load a fine-tuned model to obtain coherent results.
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+        start_scores, end_scores = model(input_dict)
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
+        assert answer == "a nice puppet"
+
+        """
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    MOBILEBERT_START_DOCSTRING,
+)
+class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to build the network.
+
+        Returns:
+            tf.Tensor with dummy inputs
+        """
+        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+    Returns:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
+            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
+
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import MobileBertTokenizer, TFMobileBertForMultipleChoice
+
+        tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
+        model = TFMobileBertForMultipleChoice.from_pretrained('google/mobilebert-uncased')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+
+        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+        labels = tf.reshape(tf.constant(1), (-1, 1))
+        outputs = model(input_ids, labels=labels)
+
+        loss, classification_scores = outputs[:2]
+
+        """
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
+            assert len(inputs) <= 7, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
+            position_ids = inputs.get("position_ids", position_ids)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            output_attentions = inputs.get("output_attentions", output_attentions)
+            assert len(inputs) <= 7, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
+
+        flat_inputs = [
+            flat_input_ids,
+            flat_attention_mask,
+            flat_token_type_ids,
+            flat_position_ids,
+            head_mask,
+            flat_inputs_embeds,
+            output_attentions,
+        ]
+
+        outputs = self.mobilebert(flat_inputs, training=training)
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output, training=training)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """MobileBert Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    MOBILEBERT_START_DOCSTRING,
+)
+class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Returns:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
+        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import MobileBertTokenizer, TFMobileBertForTokenClassification
+
+        tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
+        model = TFMobileBertForTokenClassification.from_pretrained('google/mobilebert-uncased')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+
+        """
+        outputs = self.mobilebert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output, training=training)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py
index 39f4fa3dcb4..80152e85a02 100644
--- a/src/transformers/tokenization_auto.py
+++ b/src/transformers/tokenization_auto.py
@@ -18,6 +18,8 @@ import logging
 from collections import OrderedDict
 
+from transformers.configuration_mobilebert import MobileBertConfig
+
 from .configuration_auto import (
     AlbertConfig,
     AutoConfig,
@@ -55,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
 from .tokenization_longformer import LongformerTokenizer
 from .tokenization_marian import MarianTokenizer
+from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
 from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
 from .tokenization_reformer import ReformerTokenizer
 from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
@@ -73,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
     [
         (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
         (T5Config, (T5Tokenizer, None)),
+        (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
         (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
         (AlbertConfig, (AlbertTokenizer, None)),
         (CamembertConfig, (CamembertTokenizer, None)),
diff --git a/src/transformers/tokenization_mobilebert.py b/src/transformers/tokenization_mobilebert.py
new file mode 100644
index 00000000000..5b524d97872
--- /dev/null
+++ b/src/transformers/tokenization_mobilebert.py
@@ -0,0 +1,69 @@
+# coding=utf-8
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for MobileBERT."""
+
+
+import logging
+
+from .tokenization_bert import BertTokenizer, BertTokenizerFast
+
+
+logger = logging.getLogger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/vocab.txt"
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+
+
+PRETRAINED_INIT_CONFIGURATION = {}
+
+
+class MobileBertTokenizer(BertTokenizer):
+    r"""
+    Constructs a MobileBertTokenizer.
+
+    :class:`~transformers.MobileBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
+    tokenization: punctuation splitting + wordpiece.
+
+    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
+    parameters.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+
+
+class MobileBertTokenizerFast(BertTokenizerFast):
+    r"""
+    Constructs a "Fast" MobileBertTokenizer (backed by HuggingFace's `tokenizers` library).
+
+    :class:`~transformers.MobileBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
+    tokenization: punctuation splitting + wordpiece.
+
+    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
+    parameters.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py
new file mode 100644
index 00000000000..4d2c934d9d1
--- /dev/null
+++ b/tests/test_modeling_mobilebert.py
@@ -0,0 +1,499 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
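+
+
+# A minimal usage sketch for the MobileBERT tokenizer added above: it subclasses
+# BertTokenizer, so encoding works exactly as for BERT. The helper below is
+# illustrative only (it is not used by the tests in this file) and assumes the
+# google/mobilebert-uncased vocabulary is reachable remotely.
+def _example_tokenize(text="Hello, my dog is cute"):
+    from transformers import MobileBertTokenizer
+
+    tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
+    # encode_plus returns input_ids, token_type_ids and attention_mask
+    return tokenizer.encode_plus(text, add_special_tokens=True)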
+ + +import unittest + +from transformers import is_torch_available + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from .utils import require_torch, slow, torch_device + + +if is_torch_available(): + import torch + from transformers import ( + MobileBertConfig, + MobileBertModel, + MobileBertForMaskedLM, + MobileBertForNextSentencePrediction, + MobileBertForPreTraining, + MobileBertForQuestionAnswering, + MobileBertForSequenceClassification, + MobileBertForTokenClassification, + MobileBertForMultipleChoice, + ) + from transformers.modeling_mobilebert import MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST + + +class MobileBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=64, + embedding_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.embedding_size = embedding_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = MobileBertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + embedding_size=self.embedding_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels 
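+
+    # MobileBERT decouples the embedding width (embedding_size) from the transformer
+    # width (hidden_size), which is why the config above sets both. A minimal sketch
+    # of the shape arithmetic used by TFMobileBertLMPredictionHead earlier in this
+    # patch: the tied decoder matrix covers embedding_size rows of the output
+    # projection and a separately learned matrix covers the remaining
+    # hidden_size - embedding_size rows. Illustrative helper only; no test calls it.
+    def lm_head_projection_shape_sketch(self):
+        decoder_rows = self.embedding_size  # rows tied to the input embedding matrix
+        dense_rows = self.hidden_size - self.embedding_size  # extra learned rows
+        # concat([decoder.T, dense], axis=0) therefore has shape (hidden_size, vocab_size)
+        return (decoder_rows + dense_rows, self.vocab_size)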
+ + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) + + def create_and_check_mobilebert_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MobileBertModel(config=config) + model.to(torch_device) + model.eval() + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) + sequence_output, pooled_output = model(input_ids) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + + def create_and_check_mobilebert_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = MobileBertModel(config) + model.to(torch_device) + model.eval() + sequence_output, pooled_output = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + sequence_output, pooled_output = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) + + def create_and_check_mobilebert_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MobileBertForMaskedLM(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels + ) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + ) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_next_sequence_prediction( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MobileBertForNextSentencePrediction(config=config) + model.to(torch_device) + model.eval() + loss, seq_relationship_score = model( + input_ids, 
attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels, + ) + result = { + "loss": loss, + "seq_relationship_score": seq_relationship_score, + } + self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2]) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_pretraining( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MobileBertForPreTraining(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores, seq_relationship_score = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + next_sentence_label=sequence_labels, + ) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + "seq_relationship_score": seq_relationship_score, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] + ) + self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2]) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MobileBertForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + loss, start_logits, end_logits = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + result = { + "loss": loss, + "start_logits": start_logits, + "end_logits": end_logits, + } + self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = MobileBertForSequenceClassification(config) + model.to(torch_device) + model.eval() + loss, logits = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = MobileBertForTokenClassification(config=config) + model.to(torch_device) + model.eval() + loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) + self.check_loss_output(result) + + def create_and_check_mobilebert_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = MobileBertForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + 
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + loss, logits = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) + self.check_loss_output(result) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + MobileBertModel, + MobileBertForMaskedLM, + MobileBertForMultipleChoice, + MobileBertForNextSentencePrediction, + MobileBertForPreTraining, + MobileBertForQuestionAnswering, + MobileBertForSequenceClassification, + MobileBertForTokenClassification, + ) + if is_torch_available() + else () + ) + + def setUp(self): + self.model_tester = MobileBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mobilebert_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_model(*config_and_inputs) + + def test_mobilebert_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_mobilebert_model_as_decoder(*config_and_inputs) + + def test_mobilebert_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_mobilebert_model_as_decoder( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs) + + def test_for_next_sequence_prediction(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs) + + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)
+
+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_name in MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+            model = MobileBertModel.from_pretrained(model_name)
+            self.assertIsNotNone(model)
+
+
+def _long_tensor(tok_lst):
+    return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
+
+
+TOLERANCE = 1e-3
+
+
+@require_torch
+class MobileBertModelIntegrationTests(unittest.TestCase):
+    @slow
+    def test_inference_no_head(self):
+        model = MobileBertModel.from_pretrained("google/mobilebert-uncased").to(torch_device)
+        input_ids = _long_tensor([[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]])
+        with torch.no_grad():
+            output = model(input_ids)[0]
+        expected_shape = torch.Size((1, 9, 512))
+        self.assertEqual(output.shape, expected_shape)
+        expected_slice = torch.tensor(
+            [
+                [
+                    [-2.4736526e07, 8.2691656e04, 1.6521838e05],
+                    [-5.7541704e-01, 3.9056022e00, 4.4011507e00],
+                    [2.6047359e00, 1.5677652e00, -1.7324188e-01],
+                ]
+            ],
+            device=torch_device,
+        )
+
+        # MobileBERT's output values span several orders of magnitude (roughly 1e0 up to 1e8). At the large
+        # end, even a relative error of 1e-8 translates into an absolute difference of ~1, so comparing
+        # absolute differences against a fixed tolerance is meaningless here.
+        # Instead, we divide the expected result by the actual result, which should yield ~1, and check that
+        # the ratio stays within bounds: 1 - TOLERANCE < expected_result / result < 1 + TOLERANCE
+        lower_bound = torch.all((expected_slice / output[..., :3, :3]) >= 1 - TOLERANCE)
+        upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
+
+        self.assertTrue(lower_bound and upper_bound)
diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py
new file mode 100644
index 00000000000..5b3c6b820db
--- /dev/null
+++ b/tests/test_modeling_tf_mobilebert.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
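+
+# TensorFlow 2.0 tests for MobileBERT; these mirror the PyTorch test suite in tests/test_modeling_mobilebert.py.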
+ + +import unittest + +from transformers import MobileBertConfig, is_tf_available + +from .test_configuration_common import ConfigTester +from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor +from .utils import require_tf, slow + + +if is_tf_available(): + import tensorflow as tf + from transformers.modeling_tf_mobilebert import ( + TFMobileBertModel, + TFMobileBertForMaskedLM, + TFMobileBertForNextSentencePrediction, + TFMobileBertForPreTraining, + TFMobileBertForSequenceClassification, + TFMobileBertForMultipleChoice, + TFMobileBertForTokenClassification, + TFMobileBertForQuestionAnswering, + ) + + +@require_tf +class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + TFMobileBertModel, + TFMobileBertForMaskedLM, + TFMobileBertForNextSentencePrediction, + TFMobileBertForPreTraining, + TFMobileBertForQuestionAnswering, + TFMobileBertForSequenceClassification, + TFMobileBertForTokenClassification, + TFMobileBertForMultipleChoice, + ) + if is_tf_available() + else () + ) + + class TFMobileBertModelTester(object): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + embedding_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.embedding_size = embedding_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = MobileBertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + 
attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + embedding_size=self.embedding_size, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def create_and_check_mobilebert_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFMobileBertModel(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + sequence_output, pooled_output = model(inputs) + + inputs = [input_ids, input_mask] + sequence_output, pooled_output = model(inputs) + + sequence_output, pooled_output = model(input_ids) + + result = { + "sequence_output": sequence_output.numpy(), + "pooled_output": pooled_output.numpy(), + } + self.parent.assertListEqual( + list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size]) + + def create_and_check_mobilebert_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFMobileBertForMaskedLM(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + (prediction_scores,) = model(inputs) + result = { + "prediction_scores": prediction_scores.numpy(), + } + self.parent.assertListEqual( + list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_mobilebert_for_next_sequence_prediction( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFMobileBertForNextSentencePrediction(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + (seq_relationship_score,) = model(inputs) + result = { + "seq_relationship_score": seq_relationship_score.numpy(), + } + self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2]) + + def create_and_check_mobilebert_for_pretraining( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFMobileBertForPreTraining(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + prediction_scores, seq_relationship_score = model(inputs) + result = { + "prediction_scores": prediction_scores.numpy(), + "seq_relationship_score": seq_relationship_score.numpy(), + } + self.parent.assertListEqual( + list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size] + ) + self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2]) + + def create_and_check_mobilebert_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = TFMobileBertForSequenceClassification(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels]) + + def 
create_and_check_mobilebert_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = TFMobileBertForMultipleChoice(config=config) + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1)) + inputs = { + "input_ids": multiple_choice_inputs_ids, + "attention_mask": multiple_choice_input_mask, + "token_type_ids": multiple_choice_token_type_ids, + } + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices]) + + def create_and_check_mobilebert_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = TFMobileBertForTokenClassification(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual( + list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels] + ) + + def create_and_check_mobilebert_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFMobileBertForQuestionAnswering(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + start_logits, end_logits = model(inputs) + result = { + "start_logits": start_logits.numpy(), + "end_logits": end_logits.numpy(), + } + self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = TFMobileBertModelTest.TFMobileBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mobilebert_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_model(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs) + + def test_for_next_sequence_prediction(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs) + + def test_for_pretraining(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs)
+
+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)
+
+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
+
+    @slow
+    def test_model_from_pretrained(self):
+        # for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+        for model_name in ["google/mobilebert-uncased"]:
+            model = TFMobileBertModel.from_pretrained(model_name)
+            self.assertIsNotNone(model)
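
For reference, a minimal usage sketch of the PyTorch classes this patch adds (not part of the diff itself). It assumes the patch is applied and that the `google/mobilebert-uncased` checkpoint exercised by the integration test above is available; the 512-dimensional hidden size noted in the comments comes from that test's expected output shape, and the input sentence is arbitrary.

import torch
from transformers import MobileBertModel, MobileBertTokenizer

tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertModel.from_pretrained("google/mobilebert-uncased")
model.eval()

# Models return plain tuples here: (sequence_output, pooled_output, ...),
# exactly as they are unpacked in the tests above.
input_ids = tokenizer.encode("Hello, MobileBERT!", return_tensors="pt")
with torch.no_grad():
    sequence_output, pooled_output = model(input_ids)

print(sequence_output.shape)  # torch.Size([1, seq_length, 512]) for this checkpoint
print(pooled_output.shape)    # torch.Size([1, 512])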