Optional layers (#8961)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add deprecated arguments

* Add input processing to TF XLM

* Remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Fix wrong model name

* Fix BART

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* Test ModelOutput instead of TFBaseModelOutput

* Address Patrick's comments

* Address Patrick's comments

* Add boolean processing for the inputs

* Take into account the optional layers

* Add missing/unexpected weights in the other models

* Apply style

* Rename parameters

* Apply style

* Remove useless code

* Remove useless code

* Remove useless code

* Update num parameters

* Fix tests

* Address Patrick's comment

* Remove useless attribute
Julien Plu 2020-12-08 15:14:09 +01:00 committed by GitHub
parent 9d7d0005b0
commit bf7f79cd57
17 changed files with 195 additions and 98 deletions
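Taken together, the changes below apply one pattern across the TF models: pooling layers become optional (task heads that never read the pooled output build their main layer with add_pooling_layer=False), and the weight names that may legitimately turn up missing or unexpected when converting a checkpoint between PyTorch and TensorFlow are whitelisted through the _keys_to_ignore_on_load_* regex lists. A minimal sketch of the optional-layer pattern, using simplified stand-in names rather than the actual transformers classes:

import tensorflow as tf

class MainLayerSketch(tf.keras.layers.Layer):
    # Hypothetical reduction of the TF*MainLayer pattern in this commit.
    def __init__(self, hidden_size, add_pooling_layer=True, **kwargs):
        super().__init__(**kwargs)
        self.encoder = tf.keras.layers.Dense(hidden_size, name="encoder")
        # The pooler is only built when requested, so heads that never use
        # it do not create (or expect to load) its weights.
        self.pooler = (
            tf.keras.layers.Dense(hidden_size, activation="tanh", name="pooler")
            if add_pooling_layer
            else None
        )

    def call(self, hidden_states):
        sequence_output = self.encoder(hidden_states)
        # Guard the call the same way the layer was guarded at build time.
        pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None
        return sequence_output, pooled_output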

View File

@@ -481,18 +481,22 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
class TFAlbertMainLayer(tf.keras.layers.Layer):
config_class = AlbertConfig
def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.config = config
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
self.encoder = TFAlbertTransformer(config, name="encoder")
self.pooler = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="pooler",
self.pooler = (
tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="pooler",
)
if add_pooling_layer
else None
)
def get_input_embeddings(self):
@@ -601,7 +605,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output[:, 0])
pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None
if not inputs["return_dict"]:
return (
@@ -807,6 +811,9 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@@ -914,13 +921,13 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
def get_output_embeddings(self):
@@ -1007,6 +1014,10 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
ALBERT_START_DOCSTRING,
)
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@@ -1099,14 +1110,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
ALBERT_START_DOCSTRING,
)
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1193,14 +1205,14 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
ALBERT_START_DOCSTRING,
)
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@@ -1301,6 +1313,10 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
ALBERT_START_DOCSTRING,
)
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
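The _keys_to_ignore_on_load_* entries in these classes are regular expressions, not literal layer names: any missing or unexpected weight whose name matches one of them is treated as authorized and dropped from the loading warnings. A rough sketch of that filtering step, with a hypothetical helper (the real logic lives in the TF weight-loading code):

import re

def filter_ignored(keys, ignore_patterns):
    # Keep only the keys that match none of the patterns; the rest are
    # considered authorized missing/unexpected weights and go unreported.
    return [k for k in keys if not any(re.search(p, k) for p in ignore_patterns)]

# E.g. for TFAlbertForMaskedLM above, the pooler and the tied decoder weight
# coming from a PT checkpoint are silently accepted (made-up key names):
unexpected = ["albert.pooler.kernel", "predictions.decoder.weight", "sop_classifier.bias"]
print(filter_ignored(unexpected, [r"pooler", r"predictions.decoder.weight"]))
# -> ['sop_classifier.bias']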

View File

@@ -876,6 +876,8 @@ class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
)
@keras_serializable
class TFBartModel(TFPretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
@@ -1033,10 +1035,6 @@ class TFBartModel(TFPretrainedBartModel):
BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFPretrainedBartModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
]
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",

View File

@@ -547,7 +547,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
class TFBertMainLayer(tf.keras.layers.Layer):
config_class = BertConfig
def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
self.config = config
@@ -558,7 +558,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.return_dict = config.use_return_dict
self.embeddings = TFBertEmbeddings(config, name="embeddings")
self.encoder = TFBertEncoder(config, name="encoder")
self.pooler = TFBertPooler(config, name="pooler")
self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings
@@ -663,7 +663,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]:
return (
@@ -880,6 +880,9 @@ Bert Model with two heads on top as done during the pretraining:
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.decoder.weight"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -976,9 +979,13 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"cls.seq_relationship",
r"cls.predictions.decoder.weight",
r"nsp___cls",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -989,7 +996,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
"bi-directional self-attention."
)
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
@@ -1068,9 +1075,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"cls.seq_relationship",
r"cls.predictions.decoder.weight",
r"nsp___cls",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1078,7 +1089,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
if not config.is_decoder:
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
def get_output_embeddings(self):
@@ -1165,6 +1176,9 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1262,6 +1276,10 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1353,6 +1371,10 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1477,15 +1499,21 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1571,15 +1599,20 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
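The two lists play complementary roles, and the BERT heads above use both: _keys_to_ignore_on_load_unexpected covers weights present in the checkpoint but absent from the model (the dropped pooler, the pretraining cls heads), while _keys_to_ignore_on_load_missing covers weights the model declares but the checkpoint lacks (e.g. a freshly initialized dropout/classifier scope). A toy illustration with made-up key names:

# Hypothetical key sets; real names depend on the checkpoint and model.
checkpoint_keys = {"bert.encoder.kernel", "bert.pooler.kernel", "cls.seq_relationship.weight"}
model_keys = {"bert.encoder.kernel", "classifier.kernel", "dropout"}

unexpected_keys = checkpoint_keys - model_keys  # candidates for _keys_to_ignore_on_load_unexpected
missing_keys = model_keys - checkpoint_keys     # candidates for _keys_to_ignore_on_load_missing
print(sorted(unexpected_keys))  # ['bert.pooler.kernel', 'cls.seq_relationship.weight']
print(sorted(missing_keys))     # ['classifier.kernel', 'dropout']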

View File

@@ -468,6 +468,9 @@ class TFElectraPreTrainedModel(TFPreTrainedModel):
config_class = ElectraConfig
base_model_prefix = "electra"
# When the model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
_keys_to_ignore_on_load_missing = [r"dropout"]
@keras_serializable

View File

@@ -1452,7 +1452,6 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
return_dict=inputs["return_dict"],
training=inputs["training"],
)
last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=inputs["training"])
@@ -1735,7 +1734,6 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],

View File

@@ -413,6 +413,8 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
config_class = GPT2Config
base_model_prefix = "transformer"
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias"]
@dataclass
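The GPT-2 pattern above exists because the PyTorch model registers its causal mask as a non-trainable buffer named h.<i>.attn.bias, which has no TF counterpart; the regex is presumably meant to swallow those buffers during conversion. A quick check of what it matches:

import re

pt_names = ["h.0.attn.bias", "h.11.attn.bias", "h.0.attn.c_attn.weight"]
print([n for n in pt_names if re.search(r"h.\d+.attn.bias", n)])
# -> ['h.0.attn.bias', 'h.11.attn.bias']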

View File

@@ -1566,7 +1566,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
class TFLongformerMainLayer(tf.keras.layers.Layer):
config_class = LongformerConfig
def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
if isinstance(config.attention_window, int):
@@ -1589,7 +1589,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self.attention_window = config.attention_window
self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
self.encoder = TFLongformerEncoder(config, name="encoder")
self.pooler = TFLongformerPooler(config, name="pooler")
self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings
@@ -1710,7 +1710,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
# undo padding
if padding_len > 0:
@@ -1997,13 +1997,13 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
def get_output_embeddings(self):
@@ -2091,14 +2091,14 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
@@ -2248,15 +2248,15 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
self.classifier = TFLongformerClassificationHead(config, name="classifier")
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -2346,6 +2346,9 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -2478,14 +2481,15 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config=config, name="longformer")
self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"

View File

@@ -754,7 +754,6 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")

View File

@@ -686,7 +686,7 @@ class TFMobileBertMLMHead(tf.keras.layers.Layer):
class TFMobileBertMainLayer(tf.keras.layers.Layer):
config_class = MobileBertConfig
def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
self.config = config
@@ -697,7 +697,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
self.embeddings = TFMobileBertEmbeddings(config, name="embeddings")
self.encoder = TFMobileBertEncoder(config, name="encoder")
self.pooler = TFMobileBertPooler(config, name="pooler")
self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings
@@ -801,7 +801,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]:
return (
@@ -1102,13 +1102,18 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"seq_relationship___cls",
r"predictions___cls",
r"cls.seq_relationship",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
def get_output_embeddings(self):
@@ -1170,7 +1175,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
@@ -1203,6 +1207,9 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"predictions___cls", r"cls.predictions"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1300,6 +1307,15 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"predictions___cls",
r"seq_relationship___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@@ -1393,14 +1409,20 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"predictions___cls",
r"seq_relationship___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@@ -1501,6 +1523,15 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"predictions___cls",
r"seq_relationship___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1628,14 +1659,21 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"predictions___cls",
r"seq_relationship___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1696,7 +1734,6 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])

View File

@@ -464,7 +464,7 @@ class TFRobertaEncoder(tf.keras.layers.Layer):
class TFRobertaMainLayer(tf.keras.layers.Layer):
config_class = RobertaConfig
def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
self.config = config
@@ -474,7 +474,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.encoder = TFRobertaEncoder(config, name="encoder")
self.pooler = TFRobertaPooler(config, name="pooler")
self.pooler = TFRobertaPooler(config, name="pooler") if add_pooling_layer else None
# The embeddings must be the last declaration in order to follow the weights order
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
@@ -586,7 +586,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]:
return (
@@ -798,13 +798,13 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
def get_output_embeddings(self):
@@ -917,14 +917,14 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
ROBERTA_START_DOCSTRING,
)
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.classifier = TFRobertaClassificationHead(config, name="classifier")
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -983,7 +983,6 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=inputs["training"])
@@ -1009,6 +1008,10 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
ROBERTA_START_DOCSTRING,
)
class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"lm_head"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -1129,14 +1132,15 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
ROBERTA_START_DOCSTRING,
)
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
_keys_to_ignore_on_load_missing = [r"dropout"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1224,14 +1228,14 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
ROBERTA_START_DOCSTRING,
)
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)

View File

@@ -788,6 +788,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
config_class = T5Config
base_model_prefix = "transformer"
# names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"]
@property
def dummy_inputs(self):
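The T5 pattern above is denser than the others: \W and the bracketed character classes let it match both PT names like decoder.block.0.layer.1.EncDecAttention.relative_attention_bias and the TF variable spellings that use / and _ as separators. A quick sanity check of the PT form:

import re

pattern = r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"
name = "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias"
print(bool(re.search(pattern, name)))  # True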

View File

@@ -16,6 +16,7 @@
"""
TF 2.0 Transformer XL model.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple

View File

@@ -963,7 +963,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
return_dict=inputs["return_dict"],
training=inputs["training"],
)
output = transformer_outputs[0]
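The XLM fix above is representative of the whole series of "Add input processing" commits: once every argument has been funneled through input_processing into a single inputs dict, the model body must consistently read from that dict (here, inputs["return_dict"] rather than the raw keyword). A simplified sketch of what such a canonicalization helper does, assuming it mirrors input_processing's behavior (hypothetical process_inputs, not the actual implementation):

def process_inputs(config, input_ids=None, attention_mask=None,
                   output_attentions=None, return_dict=None, **kwargs):
    # Reject unknown arguments instead of silently dropping them
    # (cf. the "Raise an error when len(kwargs)>0" commit above).
    if len(kwargs) > 0:
        raise ValueError(f"Got unexpected arguments: {list(kwargs)}")
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        # Boolean options fall back to the config when not passed explicitly.
        "output_attentions": output_attentions
        if output_attentions is not None
        else config.output_attentions,
        "return_dict": return_dict if return_dict is not None else config.use_return_dict,
    }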

View File

@@ -1667,6 +1667,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
inputs = input_processing(
func=self.call,
config=self.config,
@@ -1703,7 +1704,6 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
return_dict=inputs["return_dict"],
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.classifier(output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

View File

@@ -167,14 +167,14 @@ class TFAutoModelTest(unittest.TestCase):
def test_from_pretrained_identifier(self):
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_identifier_from_model_type(self):
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
self.assertIsInstance(model, TFRobertaForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered

View File

@@ -335,7 +335,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
"jplu/tiny-tf-bert-random", output_loading_info=True
)
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), [])
for layer in output_loading_info["missing_keys"]:
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])

View File

@@ -223,8 +223,8 @@ class TFPTAutoModelTest(unittest.TestCase):
def test_from_pretrained_identifier(self):
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, from_pt=True)
self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, from_tf=True)
self.assertIsInstance(model, BertForMaskedLM)
@@ -234,8 +234,8 @@ class TFPTAutoModelTest(unittest.TestCase):
def test_from_identifier_from_model_type(self):
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER, from_pt=True)
self.assertIsInstance(model, TFRobertaForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER, from_tf=True)
self.assertIsInstance(model, RobertaForMaskedLM)
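For reference, the expected counts in these tests drop by exactly 420 parameters. That is consistent with the pooler no longer being built by the MaskedLM heads: a pooler Dense layer holds H*H kernel weights plus H biases, and H*(H+1) = 420 for H = 20, the presumed hidden size of the tiny dummy checkpoints. A quick check:

hidden_size = 20  # assumed hidden size of the dummy models used in these tests
pooler_params = hidden_size * hidden_size + hidden_size  # kernel + bias
print(14830 - 14410 == pooler_params)  # True: the 420 missing parameters are the pooler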