Fix naming (#10095)

This commit is contained in:
Julien Plu 2021-02-09 12:10:31 +01:00 committed by GitHub
parent 4ed763779e
commit c6d5e56595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"seq_relationship___cls",
r"predictions___cls",
r"cls.seq_relationship",
]
@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
super().__init__(config, *inputs, **kwargs)
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
def get_lm_head(self):
return self.mlm.predictions
return self.predictions.predictions
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)