mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix naming (#10095)
This commit is contained in:
parent
4ed763779e
commit
c6d5e56595
@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"pooler",
|
||||
r"seq_relationship___cls",
|
||||
r"predictions___cls",
|
||||
r"cls.seq_relationship",
|
||||
]
|
||||
|
||||
@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
||||
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
|
||||
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
return self.predictions.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
|
||||
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user