TF: TFMarianMTModel final logits bias as a layer (#18833)
* bias as a layer
* alias the bias (hah, it rhymes)
* add comment with info
This commit is contained in:
parent 65fb71bc76
commit 7f27e002fd
@@ -1269,6 +1269,23 @@ class TFMarianModel(TFMarianPreTrainedModel):
         )
 
 
+class BiasLayer(tf.keras.layers.Layer):
+    """
+    Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
+    so all weights have to be registered in a layer.
+    """
+
+    def __init__(self, shape, initializer, trainable, name, **kwargs):
+        super().__init__(name=name, **kwargs)
+        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+    def call(self, x):
+        return x + self.bias
+
+
 @add_start_docstrings(
     "The MARIAN Model with a language modeling head. Can be used for summarization.",
     MARIAN_START_DOCSTRING,
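
Why the extra layer matters for serialization: below is a minimal, self-contained sketch (the ToyModel name and the shapes are illustrative, not from this commit) showing that a bias created inside its own tf.keras.layers.Layer is tracked and written out by tf.keras.Model.save_weights along with the rest of the model, which is the behavior the docstring above relies on.

# Minimal sketch, assuming TensorFlow 2.x; `ToyModel` and the shapes are hypothetical.
import tensorflow as tf


class BiasLayer(tf.keras.layers.Layer):
    """Bias as a layer, mirroring the class added in this commit."""

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias


class ToyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)
        self.bias_layer = BiasLayer(shape=[1, 4], initializer="zeros", trainable=False, name="final_logits_bias")

    def call(self, x):
        return self.bias_layer(self.dense(x))


model = ToyModel()
model(tf.zeros((2, 4)))  # run once to build the weights
print([w.name for w in model.weights])  # the bias appears unscoped, e.g. "final_logits_bias:0"
model.save_weights("toy_ckpt")  # the bias is stored with its layer in the checkpoint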
@@ -1284,9 +1301,10 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
         self.model = TFMarianMainLayer(config, name="model")
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
-        self.final_logits_bias = self.add_weight(
+        self.bias_layer = BiasLayer(
             name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
         )
+        self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT
 
     def get_decoder(self):
         return self.model.decoder
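
As a rough usage sketch (the config values below are illustrative, not from the commit), the alias means existing code that reads model.final_logits_bias still resolves to the very same non-trainable variable that now lives inside bias_layer:

from transformers import MarianConfig, TFMarianMTModel

# Tiny illustrative config; real Marian checkpoints use much larger values.
config = MarianConfig(
    vocab_size=32,
    d_model=8,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=1,
    decoder_attention_heads=1,
    encoder_ffn_dim=8,
    decoder_ffn_dim=8,
    pad_token_id=0,
    eos_token_id=1,
    decoder_start_token_id=1,
)
model = TFMarianMTModel(config)
# Both names point at the same [1, vocab_size] variable, so the PT-style attribute keeps working.
assert model.final_logits_bias is model.bias_layer.bias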
@@ -1373,7 +1391,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
             training=training,
         )
         lm_logits = self.model.shared(outputs[0], mode="linear")
-        lm_logits = lm_logits + self.final_logits_bias
+        lm_logits = self.bias_layer(lm_logits)
         masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
 
         if not return_dict:
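
A quick numerical sanity check (a standalone sketch reusing the BiasLayer definition from the sketch above, with made-up shapes): calling the layer performs the same computation as adding the raw variable, so swapping the addition for a layer call does not change the logits.

import tensorflow as tf

# Assumes the BiasLayer class from the sketch above is in scope.
bias_layer = BiasLayer(shape=[1, 5], initializer="random_normal", trainable=False, name="final_logits_bias")
logits = tf.random.normal((2, 3, 5))
tf.debugging.assert_near(bias_layer(logits), logits + bias_layer.bias)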