Fix weight tying in TF-ESM (#22839)

Author: Matt, 2023-04-20 15:50:31 +01:00 (committed via GitHub)
Commit: 6dc0a849b7 (parent 3b61d2890d)
2 changed files with 35 additions and 8 deletions

Changed file 1 of 2:

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch ESM model."""

+import os
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]

     def get_output_embeddings(self):
         return self.lm_head.decoder
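
Note (not part of the diff): the explicit build call above is needed because a Keras layer has no variables until it is built, so grabbing weights[0] from an unbuilt embedding layer would fail. A minimal standalone sketch of the same tying pattern with a plain tf.keras.layers.Embedding, purely for illustration:

import tensorflow as tf

emb = tf.keras.layers.Embedding(input_dim=32, output_dim=8)
print(emb.weights)        # [] -- nothing to tie yet, the layer is unbuilt
emb.build((None, None))   # creates the (vocab_size, hidden_size) embedding matrix
tied = emb.weights[0]     # this tf.Variable can now be reused as the output projection

hidden_states = tf.random.normal((2, 5, 8))                 # (batch, seq, hidden)
logits = tf.matmul(hidden_states, tied, transpose_b=True)   # (batch, seq, vocab) == (2, 5, 32)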
@@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer):
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config

     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.hidden_size, self.config.vocab_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)

     def get_bias(self):
@@ -1234,8 +1244,7 @@
         x = self.layer_norm(x)

         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x
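
Note (not part of the diff): one way to sanity-check the change is to confirm that the input and output embeddings end up being literally the same variable after loading. A rough sketch; the checkpoint name facebook/esm2_t6_8M_UR50D is an assumption, and any ESM-2 checkpoint whose config ties word embeddings would do:

from transformers import TFEsmForMaskedLM

# Pass from_pt=True here if the chosen checkpoint only ships PyTorch weights.
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

if model.config.tie_word_embeddings:
    input_matrix = model.get_input_embeddings().weights[0]
    output_matrix = model.get_output_embeddings()
    # With tying in place, both names should refer to the very same tf.Variable object.
    print(output_matrix is input_matrix)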

Changed file 2 of 2:

@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_save_load_after_resize_token_embeddings(self):
         pass

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            if model_class is TFEsmForMaskedLM:
+                # Output embedding test differs from the main test because they're a matrix, not a layer
+                name = model.get_bias()
+                assert isinstance(name, dict)
+                for k, v in name.items():
+                    assert isinstance(v, tf.Variable)
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None

 @require_tf
 class TFEsmModelIntegrationTest(unittest.TestCase):
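
Note (not part of the diff): the distinction the new test draws ("a matrix, not a layer") can be seen directly on a freshly constructed model. A small sketch with a toy config; all sizes and token ids below are made-up values, and tie_word_embeddings is set explicitly so the tied code path is taken:

import tensorflow as tf
from transformers import EsmConfig, TFEsmForMaskedLM

# Toy config purely for illustration; real checkpoints define these values themselves.
config = EsmConfig(
    vocab_size=40,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    pad_token_id=0,
    mask_token_id=1,
    tie_word_embeddings=True,
)
model = TFEsmForMaskedLM(config)

# The input embeddings are a Keras layer, while the output embeddings are the tied weight matrix.
print(isinstance(model.get_input_embeddings(), tf.keras.layers.Layer))  # True
print(isinstance(model.get_output_embeddings(), tf.Variable))           # True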