Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
ELECTRA (#3257)
* Electra wip
* helpers
* Electra wip
* Electra v1
* ELECTRA may be saved/loaded
* Generator & Discriminator
* Embedding size instead of halving the hidden size
* ELECTRA Tokenizer
* Revert BERT helpers
* ELECTRA Conversion script
* Archive maps
* PyTorch tests
* Start fixing tests
* Tests pass
* Same configuration for both models
* Compatible with base + large
* Simplification + weight tying
* Archives
* Auto + Renaming to standard names
* ELECTRA is uncased
* Tests
* Slight API changes
* Update tests
* wip
* ElectraForTokenClassification
* temp
* Simpler arch + tests
  Removed ElectraForPreTraining which will be in a script
* Conversion script
* Auto model
* Update links to S3
* Split ElectraForPreTraining and ElectraForTokenClassification
* Actually test PreTraining model
* Remove num_labels from configuration
* wip
* wip
* From discriminator and generator to electra
* Slight API changes
* Better naming
* TensorFlow ELECTRA tests
* Accurate conversion script
* Added to conversion script
* Fast ELECTRA tokenizer
* Style
* Add ELECTRA to README
* Modeling Pytorch Doc + Real style
* TF Docs
* Docs
* Correct links
* Correct model intialized
* random fixes
* style
* Addressing Patrick's and Sam's comments
* Correct links in docs
This commit is contained in:
parent
8594dd80dd
commit
d5d7d88612
@ -164,8 +164,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
14. **[MMBT](https://github.com/facebookresearch/mmbt/)** (from Facebook), released together with the paper [Supervised Multimodal Bitransformers for Classifying Images and Text](https://arxiv.org/pdf/1909.02950.pdf) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
15. **[FlauBERT](https://github.com/getalp/Flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab.
16. **[BART](https://github.com/pytorch/fairseq/tree/master/examples/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
17. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
18. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
17. **[ELECTRA](https://github.com/google-research/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
18. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
19. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
@ -104,3 +104,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
   model_doc/flaubert
   model_doc/bart
   model_doc/t5
   model_doc/electra
115 docs/source/model_doc/electra.rst Normal file
@ -0,0 +1,115 @@
ELECTRA
----------------------------------------------------

The ELECTRA model was proposed in the paper
`ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators <https://openreview.net/pdf?id=r1xMH1BtvB>`__.
ELECTRA is a new pre-training approach which trains two transformer models: the generator and the discriminator. The
generator's role is to replace tokens in a sequence, and it is therefore trained as a masked language model. The discriminator,
which is the model we're interested in, tries to identify which tokens were replaced by the generator in the sequence.

The abstract from the paper is the following:

*Masked language modeling (MLM) pre-training methods such as BERT corrupt
the input by replacing some tokens with [MASK] and then train a model to
reconstruct the original tokens. While they produce good results when transferred
to downstream NLP tasks, they generally require large amounts of compute to be
effective. As an alternative, we propose a more sample-efficient pre-training task
called replaced token detection. Instead of masking the input, our approach
corrupts it by replacing some tokens with plausible alternatives sampled from a small
generator network. Then, instead of training a model that predicts the original
identities of the corrupted tokens, we train a discriminative model that predicts
whether each token in the corrupted input was replaced by a generator sample
or not. Thorough experiments demonstrate this new pre-training task is more
efficient than MLM because the task is defined over all input tokens rather than
just the small subset that was masked out. As a result, the contextual representations
learned by our approach substantially outperform the ones learned by BERT
given the same model size, data, and compute. The gains are particularly strong
for small models; for example, we train a model on one GPU for 4 days that
outperforms GPT (trained using 30x more compute) on the GLUE natural language
understanding benchmark. Our approach also works well at scale, where it
performs comparably to RoBERTa and XLNet while using less than 1/4 of their
compute and outperforms them when using the same amount of compute.*

Tips:

- ELECTRA is the pre-training approach, therefore almost no changes are made to the underlying model (BERT). The
  only change is the separation of the embedding size and the hidden size: the embedding size is generally smaller,
  while the hidden size is larger. An additional projection layer (linear) is used to project the embeddings from
  their embedding size to the hidden size. In the case where the embedding size is the same as the hidden size, no
  projection layer is used.
- The ELECTRA checkpoints saved using `Google Research's implementation <https://github.com/google-research/electra>`__
  contain both the generator and discriminator. The conversion script requires the user to specify which model to export
  into the correct architecture. Once converted to the HuggingFace format, these checkpoints may however be loaded into
  all available ELECTRA models. This means that the discriminator may be loaded in the `ElectraForMaskedLM` model,
  and the generator may be loaded in the `ElectraForPreTraining` model (the classification head will be randomly
  initialized as it doesn't exist in the generator), as sketched below.
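A minimal sketch of this loading behaviour, using the ``google/electra-small-*`` checkpoints listed in the
pretrained archive maps::

    from transformers import ElectraForMaskedLM, ElectraForPreTraining

    # each checkpoint loaded into its "native" architecture
    discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
    generator = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")

    # cross-loading also works; heads missing from the checkpoint are randomly initialized
    mlm_from_discriminator = ElectraForMaskedLM.from_pretrained("google/electra-small-discriminator")
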
ElectraConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraConfig
    :members:


ElectraTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraTokenizer
    :members:


ElectraModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraModel
    :members:


ElectraForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraForPreTraining
    :members:


ElectraForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraForMaskedLM
    :members:


ElectraForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ElectraForTokenClassification
    :members:


TFElectraModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFElectraModel
    :members:


TFElectraForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFElectraForPreTraining
    :members:


TFElectraForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFElectraForMaskedLM
    :members:


TFElectraForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFElectraForTokenClassification
    :members:
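The TensorFlow classes mirror the PyTorch API. A minimal sketch, assuming TensorFlow 2 weights are available for the
checkpoint (otherwise ``from_pt=True`` can be passed to convert the PyTorch weights on the fly)::

    import tensorflow as tf
    from transformers import ElectraTokenizer, TFElectraModel

    tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
    model = TFElectraModel.from_pretrained("google/electra-small-discriminator")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # batch size 1
    last_hidden_states = model(input_ids)[0]
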
@ -38,6 +38,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
@ -127,6 +128,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
@ -297,6 +299,15 @@ if is_torch_available():
|
||||
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_electra import (
|
||||
ElectraForPreTraining,
|
||||
ElectraForMaskedLM,
|
||||
ElectraForTokenClassification,
|
||||
ElectraModel,
|
||||
load_tf_weights_in_electra,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
# Optimization
|
||||
from .optimization import (
|
||||
AdamW,
|
||||
@ -463,6 +474,15 @@ if is_tf_available():
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
from .modeling_tf_electra import (
|
||||
TFElectraPreTrainedModel,
|
||||
TFElectraModel,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForTokenClassification,
|
||||
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
# Optimization
|
||||
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
||||
|
||||
|
@ -24,6 +24,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
@ -57,6 +58,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
@ -79,6 +81,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("xlnet", XLNetConfig,),
|
||||
("xlm", XLMConfig,),
|
||||
("ctrl", CTRLConfig,),
|
||||
("electra", ElectraConfig,),
|
||||
]
|
||||
)
|
||||
|
||||
@ -133,6 +136,7 @@ class AutoConfig:
|
||||
- contains `xlm`: :class:`~transformers.XLMConfig` (XLM model)
|
||||
- contains `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
|
||||
- contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
|
||||
- contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
|
||||
|
||||
|
||||
Args:
|
||||
|
132 src/transformers/configuration_electra.py Normal file
@ -0,0 +1,132 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" ELECTRA model configuration """
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",
|
||||
"google/electra-base-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-generator/config.json",
|
||||
"google/electra-large-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-generator/config.json",
|
||||
"google/electra-small-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-discriminator/config.json",
|
||||
"google/electra-base-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/config.json",
|
||||
"google/electra-large-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-discriminator/config.json",
|
||||
}
|
||||
|
||||
|
||||
class ElectraConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.ElectraModel`.
|
||||
It is used to instantiate an ELECTRA model according to the specified arguments, defining the model
|
||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
the ELECTRA `google/electra-small-discriminator <https://huggingface.co/google/electra-small-discriminator>`__
|
||||
architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
|
||||
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, optional, defaults to 30522):
|
||||
Vocabulary size of the ELECTRA model. Defines the number of different tokens that
can be represented by the `input_ids` passed to the forward method of :class:`~transformers.ElectraModel`.
|
||||
embedding_size (:obj:`int`, optional, defaults to 128):
|
||||
Dimensionality of the embedding layer.
|
||||
hidden_size (:obj:`int`, optional, defaults to 256):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, optional, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, optional, defaults to 4):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, optional, defaults to 1024):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
||||
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, optional, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, optional, defaults to 2):
|
||||
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.ElectraModel`.
|
||||
initializer_range (:obj:`float`, optional, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
|
||||
Example::
|
||||
|
||||
from transformers import ElectraModel, ElectraConfig
|
||||
|
||||
# Initializing an ELECTRA electra-base-uncased style configuration
|
||||
configuration = ElectraConfig()
|
||||
|
||||
# Initializing a model from the electra-base-uncased style configuration
|
||||
model = ElectraModel(configuration)
|
||||
|
||||
# Accessing the model configuration
|
||||
configuration = model.config
|
||||
|
||||
Attributes:
|
||||
pretrained_config_archive_map (Dict[str, str]):
|
||||
A dictionary containing all the available pre-trained checkpoints.
|
||||
"""
|
||||
pretrained_config_archive_map = ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "electra"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
embedding_size=128,
|
||||
hidden_size=256,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=1024,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
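A brief illustration of how this configuration separates the embedding size from the hidden size; the values below
are illustrative and do not correspond to any released checkpoint:

    from transformers import ElectraConfig, ElectraModel

    # Defaults resemble google/electra-small-discriminator (see the docstring above)
    config = ElectraConfig()

    # Hypothetical variant with embedding_size < hidden_size: ElectraModel then adds a
    # linear projection from the 64-dim embeddings to the 256-dim hidden states
    custom_config = ElectraConfig(embedding_size=64, hidden_size=256, num_attention_heads=4, intermediate_size=1024)
    model = ElectraModel(custom_config)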
@ -0,0 +1,79 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ELECTRA checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
|
||||
# Initialise PyTorch model
|
||||
config = ElectraConfig.from_json_file(config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
|
||||
if discriminator_or_generator == "discriminator":
|
||||
model = ElectraForPreTraining(config)
|
||||
elif discriminator_or_generator == "generator":
|
||||
model = ElectraForMaskedLM(config)
|
||||
else:
|
||||
raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'")
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_electra(
|
||||
model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator
|
||||
)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discriminator_or_generator",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
|
||||
"'generator'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator
|
||||
)
|
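A hedged usage sketch of the conversion entry point defined above, assuming the function has been imported from this
script; all paths are placeholders:

    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path="/path/to/electra/model.ckpt",        # placeholder TF checkpoint
        config_file="/path/to/electra/config.json",              # placeholder config
        pytorch_dump_path="/path/to/output/pytorch_model.bin",   # placeholder output file
        discriminator_or_generator="discriminator",              # or "generator"
    )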
@ -25,6 +25,7 @@ from transformers import (
|
||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -39,6 +40,7 @@ from transformers import (
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
ElectraConfig,
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
@ -52,6 +54,7 @@ from transformers import (
|
||||
TFCTRLLMHeadModel,
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TFElectraForPreTraining,
|
||||
TFFlaubertWithLMHeadModel,
|
||||
TFGPT2LMHeadModel,
|
||||
TFOpenAIGPTLMHeadModel,
|
||||
@ -110,6 +113,8 @@ if is_torch_available():
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5ForConditionalGeneration,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ElectraForPreTraining,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
else:
|
||||
(
|
||||
@ -147,6 +152,8 @@ else:
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5ForConditionalGeneration,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ElectraForPreTraining,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
) = (
|
||||
None,
|
||||
None,
|
||||
@ -182,6 +189,8 @@ else:
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@ -321,6 +330,13 @@ MODEL_CLASSES = {
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
"electra": (
|
||||
ElectraConfig,
|
||||
TFElectraForPreTraining,
|
||||
ElectraForPreTraining,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ from .configuration_auto import (
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
ElectraConfig,
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
@ -76,6 +77,13 @@ from .modeling_distilbert import (
|
||||
DistilBertForTokenClassification,
|
||||
DistilBertModel,
|
||||
)
|
||||
from .modeling_electra import (
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ElectraForMaskedLM,
|
||||
ElectraForPreTraining,
|
||||
ElectraForTokenClassification,
|
||||
ElectraModel,
|
||||
)
|
||||
from .modeling_flaubert import (
|
||||
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
FlaubertForQuestionAnsweringSimple,
|
||||
@ -141,6 +149,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
@ -162,6 +171,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, FlaubertModel),
|
||||
(XLMConfig, XLMModel),
|
||||
(CTRLConfig, CTRLModel),
|
||||
(ElectraConfig, ElectraModel),
|
||||
]
|
||||
)
|
||||
|
||||
@ -182,6 +192,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
(XLMConfig, XLMWithLMHeadModel),
|
||||
(CTRLConfig, CTRLLMHeadModel),
|
||||
(ElectraConfig, ElectraForPreTraining),
|
||||
]
|
||||
)
|
||||
|
||||
@ -202,6 +213,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, FlaubertWithLMHeadModel),
|
||||
(XLMConfig, XLMWithLMHeadModel),
|
||||
(CTRLConfig, CTRLLMHeadModel),
|
||||
(ElectraConfig, ElectraForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@ -242,6 +254,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(BertConfig, BertForTokenClassification),
|
||||
(XLNetConfig, XLNetForTokenClassification),
|
||||
(AlbertConfig, AlbertForTokenClassification),
|
||||
(ElectraConfig, ElectraForTokenClassification),
|
||||
]
|
||||
)
|
||||
|
||||
@ -281,7 +294,8 @@ class AutoModel(object):
|
||||
- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
|
||||
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModel` (XLNet model)
|
||||
- isInstance of `xlm` configuration class: :class:`~transformers.XLMModel` (XLM model)
|
||||
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (XLM model)
|
||||
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertModel` (Flaubert model)
|
||||
- isInstance of `electra` configuration class: :class:`~transformers.ElectraModel` (Electra model)
|
||||
|
||||
Examples::
|
||||
|
||||
@ -322,7 +336,8 @@ class AutoModel(object):
|
||||
- contains `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
|
||||
- contains `xlm`: :class:`~transformers.XLMModel` (XLM model)
|
||||
- contains `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
|
||||
- contains `flaubert`: :class:`~transformers.Flaubert` (Flaubert model)
|
||||
- contains `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
|
||||
- contains `electra`: :class:`~transformers.ElectraModel` (Electra model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
@ -430,6 +445,7 @@ class AutoModelForPreTraining(object):
|
||||
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- isInstance of `electra` configuration class: :class:`~transformers.ElectraForPreTraining` (Electra model)
|
||||
|
||||
Examples::
|
||||
|
||||
@ -470,6 +486,7 @@ class AutoModelForPreTraining(object):
|
||||
- contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||
- contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- contains `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
@ -571,6 +588,7 @@ class AutoModelWithLMHead(object):
|
||||
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
|
||||
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- isInstance of `electra` configuration class: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||
|
||||
Examples::
|
||||
|
||||
@ -612,6 +630,7 @@ class AutoModelWithLMHead(object):
|
||||
- contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
|
||||
- contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
|
||||
- contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
|
||||
- contains `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
@ -998,6 +1017,7 @@ class AutoModelForTokenClassification:
|
||||
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModelForTokenClassification` (XLNet model)
|
||||
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertModelForTokenClassification` (Camembert model)
|
||||
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForTokenClassification` (Roberta model)
|
||||
- isInstance of `electra` configuration class: :class:`~transformers.ElectraForTokenClassification` (Electra model)
|
||||
|
||||
Examples::
|
||||
|
||||
@ -1035,6 +1055,7 @@ class AutoModelForTokenClassification:
|
||||
- contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
|
||||
- contains `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
|
||||
- contains `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
|
||||
- contains `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
To train the model, you should first set it back in training mode with `model.train()`
|
||||
|
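A minimal sketch of the Auto* integration documented above, assuming the ``google/electra-small-*`` checkpoints are
reachable on the model hub:

    from transformers import AutoConfig, AutoModel, AutoModelForPreTraining, AutoModelWithLMHead

    config = AutoConfig.from_pretrained("google/electra-small-discriminator")   # -> ElectraConfig
    base = AutoModel.from_pretrained("google/electra-small-discriminator")      # -> ElectraModel
    discriminator = AutoModelForPreTraining.from_pretrained("google/electra-small-discriminator")  # -> ElectraForPreTraining
    generator = AutoModelWithLMHead.from_pretrained("google/electra-small-generator")              # -> ElectraForMaskedLM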
@ -183,7 +183,7 @@ class BertEmbeddings(nn.Module):
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
|
671 src/transformers/modeling_electra.py Normal file
@ -0,0 +1,671 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import ElectraConfig, add_start_docstrings
|
||||
from transformers.activations import get_activation
|
||||
|
||||
from .file_utils import add_start_docstrings_to_callable
|
||||
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/pytorch_model.bin",
|
||||
"google/electra-base-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-generator/pytorch_model.bin",
|
||||
"google/electra-large-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-generator/pytorch_model.bin",
|
||||
"google/electra-small-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-discriminator/pytorch_model.bin",
|
||||
"google/electra-base-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/pytorch_model.bin",
|
||||
"google/electra-large-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-discriminator/pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
|
||||
""" Load tf checkpoints in a pytorch model.
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
for name, array in zip(names, arrays):
|
||||
original_name: str = name
|
||||
|
||||
try:
|
||||
if isinstance(model, ElectraForMaskedLM):
|
||||
name = name.replace("electra/embeddings/", "generator/embeddings/")
|
||||
|
||||
if discriminator_or_generator == "generator":
|
||||
name = name.replace("electra/", "discriminator/")
|
||||
name = name.replace("generator/", "electra/")
|
||||
|
||||
name = name.replace("dense_1", "dense_prediction")
|
||||
name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias")
|
||||
|
||||
name = name.split("/")
|
||||
# print(original_name, name)
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["global_step", "temperature"] for n in name):
|
||||
logger.info("Skipping {}".format(original_name))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
||||
scope_names = re.split(r"_(\d+)", m_name)
|
||||
else:
|
||||
scope_names = [m_name]
|
||||
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif scope_names[0] == "output_weights":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif scope_names[0] == "squad":
|
||||
pointer = getattr(pointer, "classifier")
|
||||
else:
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
if len(scope_names) >= 2:
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
if m_name.endswith("_embeddings"):
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "kernel":
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape, original_name
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name), original_name)
|
||||
pointer.data = torch.from_numpy(array)
|
||||
except AttributeError as e:
|
||||
print("Skipping {}".format(original_name), name, e)
|
||||
continue
|
||||
return model
|
||||
|
||||
|
||||
class ElectraEmbeddings(BertEmbeddings):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BertLayerNorm(config.embedding_size, eps=config.layer_norm_eps)
|
||||
|
||||
|
||||
class ElectraDiscriminatorPredictions(nn.Module):
|
||||
"""Prediction module for the discriminator, made up of two dense layers."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dense_prediction = nn.Linear(config.hidden_size, 1)
|
||||
self.config = config
|
||||
|
||||
def forward(self, discriminator_hidden_states, attention_mask):
|
||||
hidden_states = self.dense(discriminator_hidden_states)
|
||||
hidden_states = get_activation(self.config.hidden_act)(hidden_states)
|
||||
logits = self.dense_prediction(hidden_states).squeeze()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class ElectraGeneratorPredictions(nn.Module):
|
||||
"""Prediction module for the generator, made up of two dense layers."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.LayerNorm = BertLayerNorm(config.embedding_size)
|
||||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||
|
||||
def forward(self, generator_hidden_states):
|
||||
hidden_states = self.dense(generator_hidden_states)
|
||||
hidden_states = get_activation("gelu")(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ElectraPreTrainedModel(BertPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = ElectraConfig
|
||||
pretrained_model_archive_map = ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_electra
|
||||
base_model_prefix = "electra"
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask, input_shape, device):
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder:
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
causal_mask = causal_mask.to(
|
||||
attention_mask.dtype
|
||||
) # causal and attention masks must have same type with pytorch version < 1.3
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
return extended_attention_mask
|
||||
|
||||
def get_head_mask(self, head_mask):
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
num_hidden_layers = self.config.num_hidden_layers
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to float if needed + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * num_hidden_layers
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
ELECTRA_START_DOCSTRING = r"""
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
||||
usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
ELECTRA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`transformers.ElectraTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
||||
is used in the cross-attention if the model is configured as a decoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
|
||||
"the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
|
||||
"hidden size and embedding size are different."
|
||||
""
|
||||
"Both the generator and discriminator checkpoints may be loaded into this model.",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class ElectraModel(ElectraPreTrainedModel):
|
||||
|
||||
config_class = ElectraConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.embeddings = ElectraEmbeddings(config)
|
||||
|
||||
if config.embedding_size != config.hidden_size:
|
||||
self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
|
||||
|
||||
self.encoder = BertEncoder(config)
|
||||
self.config = config
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import ElectraModel, ElectraTokenizer
|
||||
import torch
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||
model = ElectraModel.from_pretrained('google/electra-small-discriminator')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
if hasattr(self, "embeddings_project"):
|
||||
hidden_states = self.embeddings_project(hidden_states)
|
||||
|
||||
hidden_states = self.encoder(hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Electra model with a binary classification head on top as used during pre-training for identifying generated
|
||||
tokens.
|
||||
|
||||
It is recommended to load the discriminator checkpoint into that model.""",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class ElectraForPreTraining(ElectraPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.electra = ElectraModel(config)
|
||||
self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` docstring)
|
||||
Indices should be in ``[0, 1]``.
|
||||
``0`` indicates the token is an original token,
|
||||
``1`` indicates the token was replaced.
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total loss of the ELECTRA objective.
|
||||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`)
|
||||
Prediction scores of the head (scores for each token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import ElectraTokenizer, ElectraForPreTraining
|
||||
import torch
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||
model = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||

        discriminator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)

        output = (logits,)

        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            if attention_mask is not None:
                active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
                active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
                active_labels = labels[active_loss]
                loss = loss_fct(active_logits, active_labels.float())
            else:
                loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())

            output = (loss,) + output

        output += discriminator_hidden_states[1:]

        return output  # (loss), scores, (hidden_states), (attentions)
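
A minimal sketch of how such ``labels`` can be built by hand (the corrupted position and the "cat" replacement are made up for illustration; actual pre-training samples replacements from the generator, and the exact token index depends on the tokenizer):

    import torch
    from transformers import ElectraTokenizer, ElectraForPreTraining

    tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
    model = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator')

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
    fake_input_ids = input_ids.clone()
    fake_input_ids[0, 4] = tokenizer.convert_tokens_to_ids("cat")  # corrupt one position (index 4 assumed to be "dog")
    labels = (fake_input_ids != input_ids).long()                  # 1 where the token was replaced, 0 elsewhere

    loss, scores = model(fake_input_ids, labels=labels)[:2]        # BCE loss over the replaced-token labels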

@add_start_docstrings(
    """
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is
    the only model of the two to have been trained for the masked language modeling task.""",
    ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)

        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        self.init_weights()

    def get_output_embeddings(self):
        return self.generator_lm_head

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
    ):
        r"""
        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see the ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        masked_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``masked_lm_labels`` is provided):
            Masked language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import ElectraTokenizer, ElectraForMaskedLM
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
        model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)

        loss, prediction_scores = outputs[:2]

        """

        generator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        generator_sequence_output = generator_hidden_states[0]

        prediction_scores = self.generator_predictions(generator_sequence_output)
        prediction_scores = self.generator_lm_head(prediction_scores)

        output = (prediction_scores,)

        # Masked language modeling loss
        if masked_lm_labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            output = (loss,) + output

        output += generator_hidden_states[1:]

        return output  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
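
As a usage note, the generator checkpoint can fill a masked position. A minimal sketch, assuming the standard BERT-style ``[MASK]`` token of the uncased vocabulary and that index 4 falls on a content word:

    import torch
    from transformers import ElectraTokenizer, ElectraForMaskedLM

    tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
    model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator')

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
    input_ids[0, 4] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)  # mask one token
    prediction_scores = model(input_ids)[0]                                  # (1, seq_len, vocab_size)
    predicted_id = prediction_scores[0, 4].argmax(-1).item()
    print(tokenizer.convert_ids_to_tokens(predicted_id))                     # generator's best guess for the masked slot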

@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
class ElectraForTokenClassification(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import ElectraTokenizer, ElectraForTokenClassification
        import torch

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = ElectraForTokenClassification.from_pretrained('google/electra-small-discriminator')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)

        loss, scores = outputs[:2]

        """

        discriminator_hidden_states = self.electra(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
        logits = self.classifier(discriminator_sequence_output)

        output = (logits,)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.config.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

            output = (loss,) + output

        output += discriminator_hidden_states[1:]

        return output  # (loss), scores, (hidden_states), (attentions)
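
A rough sketch of how the two PyTorch heads fit together at pre-training time (the actual training script is not part of this commit; this assumes the generator and discriminator checkpoints share the same uncased vocabulary, which they do for the released `google/electra-small-*` models):

    import torch
    from transformers import ElectraTokenizer, ElectraForMaskedLM, ElectraForPreTraining

    tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
    generator = ElectraForMaskedLM.from_pretrained('google/electra-small-generator')
    discriminator = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator')

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
    masked = input_ids.clone()
    masked[0, 4] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    with torch.no_grad():
        proposed = generator(masked)[0][0, 4].argmax(-1)  # generator proposes a plausible replacement
        corrupted = input_ids.clone()
        corrupted[0, 4] = proposed
        scores = discriminator(corrupted)[0]              # higher logit = "this token looks replaced"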

src/transformers/modeling_tf_electra.py (new file, 615 lines)
@ -0,0 +1,615 @@
import logging

import tensorflow as tf

from transformers import ElectraConfig

from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import get_initializer, shape_list


logger = logging.getLogger(__name__)


TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/tf_model.h5",
    "google/electra-base-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-generator/tf_model.h5",
    "google/electra-large-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-generator/tf_model.h5",
    "google/electra-small-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-discriminator/tf_model.h5",
    "google/electra-base-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/tf_model.h5",
    "google/electra-large-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-discriminator/tf_model.h5",
}

class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embedding_size = config.embedding_size
|
||||
self.initializer_range = config.initializer_range
|
||||
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.embedding_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.token_type_embeddings = tf.keras.layers.Embedding(
|
||||
config.type_vocab_size,
|
||||
config.embedding_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="token_type_embeddings",
|
||||
)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape):
|
||||
"""Build shared word embedding layer """
|
||||
with tf.name_scope("word_embeddings"):
|
||||
# Create and initialize weights. The random normal initializer was chosen
|
||||
# arbitrarily, and works well.
|
||||
self.word_embeddings = self.add_weight(
|
||||
"weight",
|
||||
shape=[self.vocab_size, self.embedding_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
mode: string, a valid value is one of "embedding" and "linear".
|
||||
Returns:
|
||||
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
|
||||
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
|
||||
linear tensor, float32 with shape [batch_size, length, vocab_size].
|
||||
Raises:
|
||||
ValueError: if mode is not valid.
|
||||
|
||||
Shared weights logic adapted from
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
else:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
"""Computes logits by running inputs through a linear layer.
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
|
||||
Returns:
|
||||
float32 tensor with shape [batch_size, length, vocab_size].
|
||||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
|
||||
x = tf.reshape(inputs, [-1, self.embedding_size])
|
||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||
|
||||
return tf.reshape(logits, [batch_size, length, self.vocab_size])
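
The ``"linear"`` mode above reuses the word embedding matrix as the output projection (weight tying). A stripped-down sketch of the two code paths, using plain TensorFlow rather than the layer itself (the sizes are arbitrary):

    import tensorflow as tf

    vocab_size, embedding_size = 100, 16
    word_embeddings = tf.random.normal([vocab_size, embedding_size])

    ids = tf.constant([[1, 2, 3]])
    hidden = tf.gather(word_embeddings, ids)                     # "embedding" mode: (1, 3, 16)
    flat = tf.reshape(hidden, [-1, embedding_size])
    logits = tf.matmul(flat, word_embeddings, transpose_b=True)  # "linear" mode: same matrix, transposed
    logits = tf.reshape(logits, [1, 3, vocab_size])              # (1, 3, 100)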
|
||||
|
||||
|
||||
class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
|
||||
self.dense_prediction = tf.keras.layers.Dense(1, name="dense_prediction")
|
||||
self.config = config
|
||||
|
||||
def call(self, discriminator_hidden_states, training=False):
|
||||
hidden_states = self.dense(discriminator_hidden_states)
|
||||
hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
|
||||
logits = tf.squeeze(self.dense_prediction(hidden_states))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")
|
||||
|
||||
def call(self, generator_hidden_states, training=False):
|
||||
hidden_states = self.dense(generator_hidden_states)
|
||||
hidden_states = ACT2FN["gelu"](hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFElectraPreTrainedModel(TFBertPreTrainedModel):
|
||||
|
||||
config_class = ElectraConfig
|
||||
pretrained_model_archive_map = TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "electra"
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask, input_shape):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
return extended_attention_mask
|
||||
|
||||
def get_head_mask(self, head_mask):
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
return head_mask
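
A small worked example of ``get_extended_attention_mask``: a padding mask of ``[1, 1, 0]`` becomes additive biases ``[0.0, 0.0, -10000.0]`` that broadcast over heads and query positions (illustration only):

    import tensorflow as tf

    attention_mask = tf.constant([[1, 1, 0]])                # last position is padding
    extended = attention_mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
    extended = (1.0 - tf.cast(extended, tf.float32)) * -10000.0
    # extended[0, 0, 0] == [0.0, 0.0, -10000.0]; added to the raw attention scores before the softmax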
|
||||
|
||||
|
||||
class TFElectraMainLayer(TFElectraPreTrainedModel):
|
||||
|
||||
config_class = ElectraConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
|
||||
|
||||
if config.embedding_size != config.hidden_size:
|
||||
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
|
||||
self.encoder = TFBertEncoder(config, name="encoder")
|
||||
self.config = config
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
training=False,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
|
||||
hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
|
||||
if hasattr(self, "embeddings_project"):
|
||||
hidden_states = self.embeddings_project(hidden_states, training=training)
|
||||
|
||||
hidden_states = self.encoder([hidden_states, extended_attention_mask, head_mask], training=training)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
ELECTRA_START_DOCSTRING = r"""
|
||||
This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
|
||||
Use it as a regular TF 2.0 Keras Model and
|
||||
refer to the TF 2.0 documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. note::
|
||||
|
||||
TF 2.0 models accepts two formats as inputs:
|
||||
|
||||
- having all inputs as keyword arguments (like PyTorch models), or
|
||||
- having all inputs as a list, tuple or dict in the first positional arguments.
|
||||
|
||||
This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having
|
||||
all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
|
||||
|
||||
If you choose this second option, there are three possibilities you can use to gather all the input Tensors
|
||||
in the first positional argument :
|
||||
|
||||
- a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
|
||||
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
||||
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
|
||||
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
||||
:obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
ELECTRA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`transformers.ElectraTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
||||
inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
|
||||
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
|
||||
(if set to :obj:`False`) for evaluation.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
|
||||
"the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
|
||||
"hidden size and embedding size are different."
|
||||
""
|
||||
"Both the generator and discriminator checkpoints may be loaded into this model.",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class TFElectraModel(TFElectraPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.electra.embeddings
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(self, inputs, **kwargs):
|
||||
r"""
|
||||
Returns:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import ElectraTokenizer, TFElectraModel
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||
model = TFElectraModel.from_pretrained('google/electra-small-discriminator')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
"""
|
||||
outputs = self.electra(inputs, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Electra model with a binary classification head on top as used during pre-training for identifying generated
|
||||
tokens.
|
||||
|
||||
Even though both the discriminator and generator may be loaded into this model, the discriminator is
|
||||
the only model of the two to have the correct classification head to be used for this model.""",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class TFElectraForPreTraining(TFElectraPreTrainedModel):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.electra.embeddings
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Prediction scores of the head (scores for each token before the sigmoid).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import ElectraTokenizer, TFElectraForPreTraining
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||
model = TFElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
scores = outputs[0]
|
||||
"""
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
logits = self.discriminator_predictions(discriminator_sequence_output)
|
||||
output = (logits,)
|
||||
output += discriminator_hidden_states[1:]
|
||||
|
||||
return output # (loss), scores, (hidden_states), (attentions)
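
The scores returned above are raw logits; a short sketch of turning them into per-token "was this token replaced?" predictions (the 0.5 threshold is an arbitrary choice for illustration):

    import tensorflow as tf
    from transformers import ElectraTokenizer, TFElectraForPreTraining

    tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
    model = TFElectraForPreTraining.from_pretrained('google/electra-small-discriminator')

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]
    logits = model(input_ids)[0]             # replaced-token logits, one per input position (squeezed)
    probs = tf.nn.sigmoid(logits)            # probability that each token was replaced
    predictions = tf.cast(probs > 0.5, tf.int32)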
|
||||
|
||||
|
||||
class TFElectraMaskedLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states, training=False):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Electra model with a language modeling head on top.
|
||||
|
||||
Even though both the discriminator and generator may be loaded into this model, the generator is
|
||||
the only model of the two to have been trained for the masked language modeling task.""",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class TFElectraForMaskedLM(TFElectraPreTrainedModel):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.activation = config.hidden_act
|
||||
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.electra.embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.generator_lm_head
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import ElectraTokenizer, TFElectraForMaskedLM
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
|
||||
model = TFElectraForMaskedLM.from_pretrained('google/electra-small-generator')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
prediction_scores = outputs[0]
|
||||
|
||||
"""
|
||||
|
||||
generator_hidden_states = self.electra(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training
|
||||
)
|
||||
generator_sequence_output = generator_hidden_states[0]
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
|
||||
output = (prediction_scores,)
|
||||
output += generator_hidden_states[1:]
|
||||
|
||||
return output # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Electra model with a token classification head on top.
|
||||
|
||||
Both the discriminator and generator may be loaded into this model.""",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class TFElectraForTokenClassification(TFElectraPreTrainedModel):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import ElectraTokenizer, TFElectraForTokenClassification
|
||||
|
||||
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
|
||||
model = TFElectraForTokenClassification.from_pretrained('google/electra-small-discriminator')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
scores = outputs[0]
|
||||
"""
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
|
||||
logits = self.classifier(discriminator_sequence_output)
|
||||
output = (logits,)
|
||||
output += discriminator_hidden_states[1:]
|
||||
|
||||
return output # (loss), scores, (hidden_states), (attentions)
|
@ -26,6 +26,7 @@ from .configuration_auto import (
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
ElectraConfig,
|
||||
FlaubertConfig,
|
||||
GPT2Config,
|
||||
OpenAIGPTConfig,
|
||||
@ -44,6 +45,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||
from .tokenization_flaubert import FlaubertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||
@ -67,6 +69,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
||||
(BartConfig, (BartTokenizer, None)),
|
||||
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
|
||||
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
|
||||
(BertConfig, (BertTokenizer, BertTokenizerFast)),
|
||||
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
|
||||
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
|
||||
@ -104,6 +107,7 @@ class AutoTokenizer:
|
||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||
- contains `xlm`: XLMTokenizer (XLM model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
- contains `electra`: ElectraTokenizer (Google ELECTRA model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throw an error).
|
||||
"""
|
||||
@ -135,6 +139,7 @@ class AutoTokenizer:
|
||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||
- contains `xlm`: XLMTokenizer (XLM model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
- contains `electra`: ElectraTokenizer (Google ELECTRA model)
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||

src/transformers/tokenization_electra.py (new file, 80 lines)
@ -0,0 +1,80 @@
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/vocab.txt",
|
||||
"google/electra-base-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-generator/vocab.txt",
|
||||
"google/electra-large-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-generator/vocab.txt",
|
||||
"google/electra-small-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-discriminator/vocab.txt",
|
||||
"google/electra-base-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/vocab.txt",
|
||||
"google/electra-large-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-discriminator/vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"google/electra-small-generator": 512,
|
||||
"google/electra-base-generator": 512,
|
||||
"google/electra-large-generator": 512,
|
||||
"google/electra-small-discriminator": 512,
|
||||
"google/electra-base-discriminator": 512,
|
||||
"google/electra-large-discriminator": 512,
|
||||
}
|
||||
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"google/electra-small-generator": {"do_lower_case": True},
|
||||
"google/electra-base-generator": {"do_lower_case": True},
|
||||
"google/electra-large-generator": {"do_lower_case": True},
|
||||
"google/electra-small-discriminator": {"do_lower_case": True},
|
||||
"google/electra-base-discriminator": {"do_lower_case": True},
|
||||
"google/electra-large-discriminator": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
class ElectraTokenizer(BertTokenizer):
|
||||
r"""
|
||||
Constructs an Electra tokenizer.
|
||||
:class:`~transformers.ElectraTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
|
||||
class ElectraTokenizerFast(BertTokenizerFast):
|
||||
r"""
|
||||
Constructs an Electra Fast tokenizer.
|
||||
:class:`~transformers.ElectraTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||
tokenization: punctuation splitting + wordpiece.
|
||||
|
||||
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||
parameters.
|
||||
"""
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
287
tests/test_modeling_electra.py
Normal file
287
tests/test_modeling_electra.py
Normal file
@ -0,0 +1,287 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (
|
||||
ElectraConfig,
|
||||
ElectraModel,
|
||||
ElectraForMaskedLM,
|
||||
ElectraForTokenClassification,
|
||||
ElectraForPreTraining,
|
||||
)
|
||||
from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@require_torch
|
||||
class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(ElectraModel, ElectraForMaskedLM, ElectraForTokenClassification,) if is_torch_available() else ()
|
||||
)
|
||||
|
||||
class ElectraModelTester(object):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
fake_token_labels = ids_tensor([self.batch_size, self.seq_length], 1)
|
||||
|
||||
config = ElectraConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_electra_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
model = ElectraModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
(sequence_output,) = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
(sequence_output,) = model(input_ids, token_type_ids=token_type_ids)
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_electra_for_masked_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
model = ElectraForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_electra_for_token_classification(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = ElectraForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_electra_for_pretraining(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = ElectraForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
fake_token_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ElectraModelTest.ElectraModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ElectraConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_electra_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_pre_training(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in list(ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = ElectraModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
227
tests/test_modeling_tf_electra.py
Normal file
227
tests/test_modeling_tf_electra.py
Normal file
@ -0,0 +1,227 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import ElectraConfig, is_tf_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
from transformers.modeling_tf_electra import (
|
||||
TFElectraModel,
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForTokenClassification,
|
||||
)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
class TFElectraModelTester(object):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = ElectraConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_electra_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFElectraModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(sequence_output,) = model(inputs)
|
||||
|
||||
inputs = [input_ids, input_mask]
|
||||
(sequence_output,) = model(inputs)
|
||||
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_electra_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFElectraForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(prediction_scores,) = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
|
||||
def create_and_check_electra_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFElectraForPreTraining(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(prediction_scores,) = model(inputs)
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_electra_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFElectraForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFElectraModelTest.TFElectraModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ElectraConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_electra_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
# for model_name in list(TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in ["electra-small-discriminator"]:
|
||||
model = TFElectraModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
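The TF tester above exercises three input formats (dict, list, bare tensor) against `TFElectraModel`. As a compact, hedged illustration of the same forward-pass pattern that needs no pretrained checkpoint, a tiny randomly initialized model can be built straight from an `ElectraConfig`; the hyperparameters below are illustrative, mirroring the tester defaults.

# Illustrative sketch: tiny TFElectraModel built from a config, checking the
# output shape the same way create_and_check_electra_model does above.
import tensorflow as tf
from transformers import ElectraConfig
from transformers.modeling_tf_electra import TFElectraModel

config = ElectraConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
)
model = TFElectraModel(config=config)

input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7]])   # [batch_size=1, seq_length=7]
(sequence_output,) = model(input_ids)              # same call pattern as the tester
assert sequence_output.shape == (1, 7, config.hidden_size)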