mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Optional layers (#8961)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Fix wrong model name * Fix BART * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Address Patrick's comments * Address Patrick's comments * Add boolean processing for the inputs * Take into account the optional layers * Add missing/unexpected weights in the other models * Apply style * rename parameters * Apply style * Remove useless * Remove useless * Remove useless * Update num parameters * Fix tests * Address Patrick's comment * Remove useless attribute
This commit is contained in:
parent
9d7d0005b0
commit
bf7f79cd57
@ -481,18 +481,22 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
||||
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = AlbertConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.config = config
|
||||
|
||||
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFAlbertTransformer(config, name="encoder")
|
||||
self.pooler = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="pooler",
|
||||
self.pooler = (
|
||||
tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="pooler",
|
||||
)
|
||||
if add_pooling_layer
|
||||
else None
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@ -601,7 +605,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output[:, 0])
|
||||
pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return (
|
||||
@ -807,6 +811,9 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
@ -914,13 +921,13 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
|
||||
|
||||
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
|
||||
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
||||
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -1007,6 +1014,10 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"predictions"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
@ -1099,14 +1110,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
@ -1193,14 +1205,14 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.albert = TFAlbertMainLayer(config, name="albert")
|
||||
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
@ -1301,6 +1313,10 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
ALBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
|
@ -876,6 +876,8 @@ class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
|
||||
)
|
||||
@keras_serializable
|
||||
class TFBartModel(TFPretrainedBartModel):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: BartConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
@ -1033,10 +1035,6 @@ class TFBartModel(TFPretrainedBartModel):
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"final_logits_bias",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"model.encoder.embed_tokens.weight",
|
||||
r"model.decoder.embed_tokens.weight",
|
||||
|
@ -547,7 +547,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BertConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
@ -558,7 +558,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
self.return_dict = config.use_return_dict
|
||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFBertEncoder(config, name="encoder")
|
||||
self.pooler = TFBertPooler(config, name="pooler")
|
||||
self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
@ -663,7 +663,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return (
|
||||
@ -880,6 +880,9 @@ Bert Model with two heads on top as done during the pretraining:
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -976,9 +979,13 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"cls.seq_relationship",
|
||||
r"cls.predictions.decoder.weight",
|
||||
r"nsp___cls",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
@ -989,7 +996,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
"bi-directional self-attention."
|
||||
)
|
||||
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -1068,9 +1075,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
|
||||
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"cls.seq_relationship",
|
||||
r"cls.predictions.decoder.weight",
|
||||
r"nsp___cls",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
@ -1078,7 +1089,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
if not config.is_decoder:
|
||||
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")
|
||||
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -1165,6 +1176,9 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1262,6 +1276,10 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1353,6 +1371,10 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1477,15 +1499,21 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"mlm___cls",
|
||||
r"nsp___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
@ -1571,15 +1599,20 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"mlm___cls",
|
||||
r"nsp___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
@ -468,6 +468,9 @@ class TFElectraPreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = ElectraConfig
|
||||
base_model_prefix = "electra"
|
||||
# When the model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
|
||||
@keras_serializable
|
||||
|
@ -1452,7 +1452,6 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
last_hidden_state = outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0]
|
||||
logits = self.classifier(pooled_output, training=inputs["training"])
|
||||
@ -1735,7 +1734,6 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
outputs = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
|
@ -413,6 +413,8 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = GPT2Config
|
||||
base_model_prefix = "transformer"
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias"]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1566,7 +1566,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
config_class = LongformerConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if isinstance(config.attention_window, int):
|
||||
@ -1589,7 +1589,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
self.attention_window = config.attention_window
|
||||
self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFLongformerEncoder(config, name="encoder")
|
||||
self.pooler = TFLongformerPooler(config, name="pooler")
|
||||
self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
@ -1710,7 +1710,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
# undo padding
|
||||
if padding_len > 0:
|
||||
@ -1997,13 +1997,13 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
|
||||
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -2091,14 +2091,14 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@ -2248,15 +2248,15 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
|
||||
self.classifier = TFLongformerClassificationHead(config, name="classifier")
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@ -2346,6 +2346,9 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -2478,14 +2481,15 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.longformer = TFLongformerMainLayer(config=config, name="longformer")
|
||||
self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
|
@ -754,7 +754,6 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
|
||||
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
|
||||
|
||||
|
@ -686,7 +686,7 @@ class TFMobileBertMLMHead(tf.keras.layers.Layer):
|
||||
class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = MobileBertConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
@ -697,7 +697,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
self.embeddings = TFMobileBertEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFMobileBertEncoder(config, name="encoder")
|
||||
self.pooler = TFMobileBertPooler(config, name="pooler")
|
||||
self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
@ -801,7 +801,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return (
|
||||
@ -1102,13 +1102,18 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
||||
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"seq_relationship___cls",
|
||||
r"predictions___cls",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
|
||||
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
||||
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -1170,7 +1175,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
|
||||
@ -1203,6 +1207,9 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"predictions___cls", r"cls.predictions"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1300,6 +1307,15 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"predictions___cls",
|
||||
r"seq_relationship___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
@ -1393,14 +1409,20 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"predictions___cls",
|
||||
r"seq_relationship___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
|
||||
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
@ -1501,6 +1523,15 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"predictions___cls",
|
||||
r"seq_relationship___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1628,14 +1659,21 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"predictions___cls",
|
||||
r"seq_relationship___cls",
|
||||
r"cls.predictions",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
|
||||
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
@ -1696,7 +1734,6 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
|
@ -464,7 +464,7 @@ class TFRobertaEncoder(tf.keras.layers.Layer):
|
||||
class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
config_class = RobertaConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
@ -474,7 +474,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.return_dict = config.use_return_dict
|
||||
self.encoder = TFRobertaEncoder(config, name="encoder")
|
||||
self.pooler = TFRobertaPooler(config, name="pooler")
|
||||
self.pooler = TFRobertaPooler(config, name="pooler") if add_pooling_layer else None
|
||||
# The embeddings must be the last declaration in order to follow the weights order
|
||||
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
|
||||
|
||||
@ -586,7 +586,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return (
|
||||
@ -798,13 +798,13 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.roberta = TFRobertaMainLayer(config, name="roberta")
|
||||
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
|
||||
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -917,14 +917,14 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.roberta = TFRobertaMainLayer(config, name="roberta")
|
||||
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
|
||||
self.classifier = TFRobertaClassificationHead(config, name="classifier")
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@ -983,7 +983,6 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output, training=inputs["training"])
|
||||
|
||||
@ -1009,6 +1008,10 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"lm_head"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@ -1129,14 +1132,15 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
|
||||
_keys_to_ignore_on_load_missing = [r"dropout"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.roberta = TFRobertaMainLayer(config, name="roberta")
|
||||
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
@ -1224,14 +1228,14 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"pooler"]
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.roberta = TFRobertaMainLayer(config, name="roberta")
|
||||
self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
@ -788,6 +788,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
||||
|
||||
config_class = T5Config
|
||||
base_model_prefix = "transformer"
|
||||
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"]
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
|
@ -16,6 +16,7 @@
|
||||
"""
|
||||
TF 2.0 Transformer XL model.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
@ -963,7 +963,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
|
@ -1667,6 +1667,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@ -1703,7 +1704,6 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
logits = self.classifier(output)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
@ -167,14 +167,14 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
def test_from_pretrained_identifier(self):
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_from_identifier_from_model_type(self):
|
||||
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_parents_and_children_in_mappings(self):
|
||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||
|
@ -335,7 +335,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
|
||||
"jplu/tiny-tf-bert-random", output_loading_info=True
|
||||
)
|
||||
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
|
||||
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), [])
|
||||
for layer in output_loading_info["missing_keys"]:
|
||||
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
|
||||
|
||||
|
@ -223,8 +223,8 @@ class TFPTAutoModelTest(unittest.TestCase):
|
||||
def test_from_pretrained_identifier(self):
|
||||
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, from_pt=True)
|
||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, from_tf=True)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
@ -234,8 +234,8 @@ class TFPTAutoModelTest(unittest.TestCase):
|
||||
def test_from_identifier_from_model_type(self):
|
||||
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER, from_pt=True)
|
||||
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER, from_tf=True)
|
||||
self.assertIsInstance(model, RobertaForMaskedLM)
|
||||
|
Loading…
Reference in New Issue
Block a user