RFC: Replace custom TF embeddings by Keras embeddings (#18939)

Joao Gante 2022-09-10 11:34:49 +01:00 committed by GitHub
parent 855dcae8bb
commit 00cbadb870
5 changed files with 141 additions and 142 deletions


@ -887,6 +887,12 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
# If not found, set the value to None
saved_weight_value = saved_weights.get(symbolic_weight_name, None)
# Retrocompatibility patch: some embeddings are stored under the weight name (e.g. Bart's
# `model.shared/embeddings:0` is stored as `model.shared/weight:0`)
if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
saved_weight_value = saved_weights.get(symbolic_weight_name, None)
# Add the updated name to the final list for computing missing/unexpected values
symbolic_weights_names.add(symbolic_weight_name)
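For illustration, the remapping behaves as follows for the Bart shared embedding; a standalone sketch using the name from the comment above, not taken from a real weight file:

    symbolic_weight_name = "model.shared/embeddings:0"
    if symbolic_weight_name.endswith("embeddings:0"):
        # drop the 12-character "embeddings:0" suffix and fall back to the legacy name
        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
    assert symbolic_weight_name == "model.shared/weight:0"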
@ -1700,7 +1706,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
return None
def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None
) -> Union[tf.keras.layers.Embedding, tf.Variable]:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@ -1710,11 +1718,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
new_num_tokens (`int`, *optional*):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `tf.Variable` module of the model without doing anything.
returns a pointer to the input tokens without doing anything.
Return:
`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
`tf.Variable` or `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
"""
# TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
# Run the new code path if the model has a keras embeddings layer
if isinstance(self.get_input_embeddings(), tf.keras.layers.Embedding):
return self._v2_resized_token_embeddings(new_num_tokens)
if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
return self._get_word_embedding_weight(self.get_input_embeddings())
@ -1725,7 +1739,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return model_embeds
def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> tf.keras.layers.Embedding:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
Arguments:
new_num_tokens (`int`, *optional*):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens without doing anything.
Return:
`tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
"""
if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
return self.get_input_embeddings()
model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
# Update base model and current model config
self.config.vocab_size = new_num_tokens
return model_embeds
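A minimal usage sketch of the new resize path, assuming a model whose input embeddings are already a `tf.keras.layers.Embedding` (such as the refactored TFBart in this diff); the checkpoint name is only an example:

    import tensorflow as tf
    from transformers import BartTokenizer, TFBartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

    tokenizer.add_tokens(["<new_token>"])
    # The input embeddings are now a tf.keras.layers.Embedding, so this call
    # dispatches to `_v2_resized_token_embeddings`
    embeddings = model.resize_token_embeddings(len(tokenizer))
    assert isinstance(embeddings, tf.keras.layers.Embedding)
    assert embeddings.input_dim == len(tokenizer)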
def _get_word_embedding_weight(model, embedding_layer):
# TODO (joao): flagged for deletion due to embeddings refactor
# If the variable holds the weights themselves, return them
if isinstance(embedding_layer, tf.Tensor):
return embedding_layer
@ -1755,6 +1794,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return None
def _resize_token_embeddings(self, new_num_tokens):
# TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
@ -1776,6 +1816,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return self.get_input_embeddings()
def _v2_resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)
# If word embeddings are not tied, make sure that lm head bias is resized as well
if self.get_bias() is not None:
old_lm_head_bias = self.get_bias()
new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
self.set_bias(new_lm_head_bias)
# If word embeddings are not tied, make sure that lm head decoder is resized as well.
tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
if self.get_output_embeddings() is not None and not tied_weights:
old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
# TODO (joao): this one probably needs a v2 version with other models
new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
self.set_output_embeddings(new_lm_head_decoder)
return self.get_input_embeddings()
def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
"""
Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
@ -1885,6 +1946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
`None`
"""
# TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
old_embedding_dim = shape_list(old_embeddings)[1]
init_range = getattr(self.config, "initializer_range", 0.02)
embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
@ -1900,6 +1962,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return new_embeddings
def _v2_get_resized_embeddings(
self, old_embeddings: tf.keras.layers.Embedding, new_num_tokens: int
) -> tf.keras.layers.Embedding:
"""
Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end.
Args:
old_embeddings (`tf.keras.layers.Embedding`):
Old embeddings to be resized.
new_num_tokens (`int`, *optional*):
New number of tokens in the embedding matrix.
Return:
`tf.keras.layers.Embedding`: Resized Embedding layer.
"""
# Get a new (initialized) embeddings layer
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = tf.keras.layers.Embedding(
input_dim=new_num_tokens,
output_dim=old_embeddings.output_dim,
embeddings_initializer=get_initializer(init_range),
name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
)
new_embeddings(tf.constant([[0]]))
# Copy the old embeddings to the new embeddings
if old_embeddings.input_dim >= new_num_tokens:
init_embeddings = old_embeddings.embeddings[:new_num_tokens]
else:
init_embeddings = tf.concat(
[old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
)
new_embeddings.embeddings.assign(init_embeddings)
return new_embeddings
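The copy step above can be exercised in isolation; a sketch with toy sizes (4 tokens grown to 6, hidden size 8):

    import tensorflow as tf

    old = tf.keras.layers.Embedding(input_dim=4, output_dim=8)
    old(tf.constant([[0]]))  # build the layer so `old.embeddings` exists

    new = tf.keras.layers.Embedding(input_dim=6, output_dim=8)
    new(tf.constant([[0]]))  # build with freshly initialized rows

    # keep the 4 existing rows and the 2 newly initialized rows for the added ids
    init = tf.concat([old.embeddings, new.embeddings[old.input_dim:]], axis=0)
    new.embeddings.assign(init)
    assert new.embeddings.shape == (6, 8)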
def prune_heads(self, heads_to_prune):
"""
Prunes heads of the base model.
@ -2632,6 +2730,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
kwargs:
Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
"""
# TODO (joao): flagged for deletion due to embeddings refactor
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
@ -2848,6 +2947,8 @@ class TFWrappedEmbeddings:
saving/storing the correct weights
"""
# TODO (joao): flagged for deletion due to embeddings refactor
def __init__(self, layer, abs_scope_name=None):
self._layer = layer
self._abs_scope_name = abs_scope_name


@ -35,8 +35,6 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
@ -113,7 +111,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFBartLearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
@ -136,7 +134,8 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
return super().call(position_ids + self.offset)
offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))
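The explicit dtype handling presumably covers `position_ids` being passed in as a tensor whose integer dtype differs from the default int32, since TensorFlow will not add an int32 constant to an int64 tensor. A small sketch of the dtype-matched addition (toy values):

    import tensorflow as tf

    offset = 2
    position_ids = tf.range(5, dtype=tf.int64)

    # a bare tf.constant(offset) would be int32 and could not be added to an
    # int64 tensor, so the offset is created with the dtype of `position_ids`
    offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
    shifted = position_ids + tf.constant(offset, dtype=offset_dtype)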
class TFBartAttention(tf.keras.layers.Layer):
@ -667,7 +666,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
config: BartConfig
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = tf.keras.layers.Dropout(config.dropout)
@ -685,12 +684,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
@ -750,7 +743,8 @@ class TFBartEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
with tf.name_scope(self.embed_tokens.name + "/"):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
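The trailing "/" in the `tf.name_scope(self.embed_tokens.name + "/")` wrapper above matters: `tf.name_scope` treats a name ending in "/" as a scope to re-enter verbatim rather than a new scope to nest under the current one, so the shared embedding weight is built under the shared layer's own scope instead of under the encoder (or decoder) scope, which appears to be what keeps the stored weight names loadable from existing checkpoints. A minimal sketch of the scoping behavior, standalone and outside the model:

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope("encoder"):
            a = tf.constant(0.0, name="a")        # named "encoder/a:0"
            with tf.name_scope("model.shared/"):  # trailing "/" re-enters that scope verbatim
                b = tf.constant(0.0, name="b")    # named "model.shared/b:0", not nested under "encoder"
    print(a.name, b.name)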
@ -820,7 +814,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
embed_tokens: output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
@ -837,12 +831,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout)
def get_embed_tokens(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
@unpack_inputs
def call(
self,
@ -943,7 +931,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
with tf.name_scope(self.embed_tokens.name + "/"):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1038,36 +1027,19 @@ class TFBartMainLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
# set tf scope correctly
if load_weight_prefix is None:
load_weight_prefix = "model.shared"
with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
self.encoder = TFBartEncoder(config, self.shared, name="encoder")
self.decoder = TFBartDecoder(config, self.shared, name="decoder")
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
@unpack_inputs
def call(
@ -1273,11 +1245,7 @@ class BiasLayer(tf.keras.layers.Layer):
BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = [r"final_logits_bias"]
_requires_load_weight_prefix = True
def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
@ -1303,10 +1271,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
self.set_input_embeddings(value)
def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
return {"final_logits_bias": self.bias_layer.bias}
def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
self.bias_layer.bias = value["final_logits_bias"]
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@ -1374,7 +1342,9 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
return_dict=return_dict,
training=training,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
# TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
# The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
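For context, the new `lm_logits` line reproduces the `mode="linear"` behaviour of the removed `TFSharedEmbeddings`: project the decoder hidden states onto the vocabulary by multiplying with the transposed embedding matrix. A sketch with toy shapes, using the layer's `embeddings` variable:

    import tensorflow as tf

    hidden_states = tf.random.normal((2, 7, 16))   # (batch, seq_len, d_model)
    shared = tf.keras.layers.Embedding(50, 16)     # weight matrix of shape (vocab_size, d_model)
    shared(tf.constant([[0]]))                     # build the layer

    lm_logits = tf.matmul(hidden_states, shared.embeddings, transpose_b=True)
    assert lm_logits.shape == (2, 7, 50)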


@ -137,7 +137,8 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
return super().call(position_ids + self.offset)
offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart


@ -230,69 +230,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
name = model.get_bias()
assert name is None
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def _get_word_embedding_weight(model, embedding_layer):
if hasattr(embedding_layer, "weight"):
return embedding_layer.weight
else:
# Here we build the word embeddings weights if not exists.
# And then we retry to get the attribute once built.
model(model.dummy_inputs)
if hasattr(embedding_layer, "weight"):
return embedding_layer.weight
else:
return None
for model_class in self.all_model_classes:
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
# build the embeddings
model = model_class(config=config)
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
old_final_logits_bias = model.get_bias()
# reshape the embeddings
model.resize_token_embeddings(size)
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
new_final_logits_bias = model.get_bias()
# check that the resized embeddings size matches the desired size.
assert_size = size if size is not None else config.vocab_size
self.assertEqual(new_input_embeddings.shape[0], assert_size)
# check that weights remain the same after resizing
models_equal = True
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
if old_output_embeddings is not None and new_output_embeddings is not None:
self.assertEqual(new_output_embeddings.shape[0], assert_size)
models_equal = True
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
if old_final_logits_bias is not None and new_final_logits_bias is not None:
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
self.assertEqual(new_final_logits_bias.shape[0], 1)
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
models_equal = True
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
for p1, p2 in zip(old, new):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
pass
@ -635,7 +572,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
def test_xsum_1_1_generation(self):
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
assert model.model.decoder.embed_tokens == model.model.shared
ARTICLE = (
"The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
@ -685,7 +622,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
def test_xsum_1_1_xla_generation(self):
# same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
assert model.model.decoder.embed_tokens == model.model.shared
ARTICLE = (
"The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"


@ -1144,30 +1144,20 @@ class TFModelTesterMixin:
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_resize_token_embeddings(self):
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
# tf.keras.layers.Embedding
if not self.test_resize_embeddings:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def _get_word_embedding_weight(model, embedding_layer):
embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
return embeds
embeds = getattr(embedding_layer, "decoder", None)
if embeds is not None:
return embeds
model(model.dummy_inputs)
embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
return embeds
embeds = getattr(embedding_layer, "decoder", None)
if embeds is not None:
return embeds
return None
if isinstance(embedding_layer, tf.keras.layers.Embedding):
# builds the embeddings layer
model(model.dummy_inputs)
return embedding_layer.embeddings
else:
return model._get_word_embedding_weight(embedding_layer)
for model_class in self.all_model_classes:
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
@ -1195,10 +1185,10 @@ class TFModelTesterMixin:
if old_bias is not None and new_bias is not None:
for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
self.assertEqual(new_weight.shape[0], assert_size)
self.assertEqual(new_weight.shape[-1], assert_size)
models_equal = True
for p1, p2 in zip(old_weight.value(), new_weight.value()):
for p1, p2 in zip(tf.squeeze(old_weight), tf.squeeze(new_weight)):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
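This change follows from the shape of Bart's `final_logits_bias`, which is kept as `(1, vocab_size)`: the vocabulary dimension is the last axis, and the leading singleton axis has to be squeezed away before comparing entries one by one. A toy sketch of the updated check:

    import tensorflow as tf

    assert_size = 60
    old_bias = tf.zeros((1, 50))
    new_bias = tf.zeros((1, assert_size))

    assert new_bias.shape[-1] == assert_size
    models_equal = True
    for p1, p2 in zip(tf.squeeze(old_bias), tf.squeeze(new_bias)):
        if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
            models_equal = False
    assert models_equal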