mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)

RFC: Replace custom TF embeddings by Keras embeddings (#18939)

This commit is contained in:
parent 855dcae8bb
commit 00cbadb870
@@ -887,6 +887,12 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
             # If not, make the value to None
             saved_weight_value = saved_weights.get(symbolic_weight_name, None)

+            # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
+            # `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
+            if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
+                symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
+                saved_weight_value = saved_weights.get(symbolic_weight_name, None)
+
             # Add the updated name to the final list for computing missing/unexpected values
             symbolic_weights_names.add(symbolic_weight_name)
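The fallback above only renames the lookup key and retries once. As a standalone illustration (the helper name and the toy dictionary are made up for this example, not part of the patch):

# Standalone sketch of the legacy-name fallback shown in the hunk above.
def lookup_with_legacy_name(saved_weights: dict, symbolic_weight_name: str):
    value = saved_weights.get(symbolic_weight_name, None)
    if value is None and symbolic_weight_name.endswith("embeddings:0"):
        # "embeddings:0" is 12 characters long, hence the [:-12] slice
        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
        value = saved_weights.get(symbolic_weight_name, None)
    return symbolic_weight_name, value

legacy_checkpoint = {"model.shared/weight:0": [[0.1, 0.2]]}
print(lookup_with_legacy_name(legacy_checkpoint, "model.shared/embeddings:0"))
# -> ('model.shared/weight:0', [[0.1, 0.2]])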
@@ -1700,7 +1706,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         return None

-    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> Union[tf.keras.layers.Embedding, tf.Variable]:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1710,11 +1718,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             new_num_tokens (`int`, *optional*):
                 The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                 vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
-                returns a pointer to the input tokens `tf.Variable` module of the model without doing anything.
+                returns a pointer to the input tokens without doing anything.

         Return:
-            `tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
+            `tf.Variable` or `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
         """
+        # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
+
+        # Run the new code path if the model has a keras embeddings layer
+        if isinstance(self.get_input_embeddings(), tf.keras.layers.Embedding):
+            return self._v2_resized_token_embeddings(new_num_tokens)
+
         if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
             return self._get_word_embedding_weight(self.get_input_embeddings())
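From the caller's side nothing changes except the return type. A usage sketch with the standard `transformers` API (the checkpoint and the added tokens below are arbitrary examples):

from transformers import BartTokenizer, TFBartForConditionalGeneration

# Caller-side view of resize_token_embeddings.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
new_embeddings = model.resize_token_embeddings(len(tokenizer))
print(type(new_embeddings))  # a tf.Variable on the old code path, tf.keras.layers.Embedding on the new one
assert model.config.vocab_size == len(tokenizer)  # the config is updated as well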
@@ -1725,7 +1739,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return model_embeds

+    def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> tf.keras.layers.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens without doing anything.
+
+        Return:
+            `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
+        """
+        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+            return self.get_input_embeddings()
+
+        model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
+
+        # Update base model and current model config
+        self.config.vocab_size = new_num_tokens
+
+        return model_embeds
+
     def _get_word_embedding_weight(model, embedding_layer):
+        # TODO (joao): flagged for deletion due to embeddings refactor
+
         # If the variable holds the weights themselves, return them
         if isinstance(embedding_layer, tf.Tensor):
             return embedding_layer
@@ -1755,6 +1794,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         return None

     def _resize_token_embeddings(self, new_num_tokens):
+        # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
         old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)

@@ -1776,6 +1816,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return self.get_input_embeddings()

+    def _v2_resize_token_embeddings(self, new_num_tokens):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+
+        # If word embeddings are not tied, make sure that lm head bias is resized as well
+        if self.get_bias() is not None:
+            old_lm_head_bias = self.get_bias()
+            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+            self.set_bias(new_lm_head_bias)
+
+        # If word embeddings are not tied, make sure that lm head decoder is resized as well.
+        tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
+        if self.get_output_embeddings() is not None and not tied_weights:
+            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+            # TODO (joao): this one probably needs a v2 version with other models
+            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+            self.set_output_embeddings(new_lm_head_decoder)
+
+        return self.get_input_embeddings()
+
     def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
         """
         Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
@@ -1885,6 +1946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
             `None`
         """
+        # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
         old_embedding_dim = shape_list(old_embeddings)[1]
         init_range = getattr(self.config, "initializer_range", 0.02)
         embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
@@ -1900,6 +1962,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return new_embeddings

+    def _v2_get_resized_embeddings(
+        self, old_embeddings: tf.keras.layers.Embedding, new_num_tokens: int
+    ) -> tf.keras.layers.Embedding:
+        """
+        Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
+        vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_embeddings (`tf.keras.layers.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+        Return:
+            `tf.keras.layers.Embedding`: Resized Embedding layer.
+        """
+        # Get a new (initialized) embeddings layer
+        init_range = getattr(self.config, "initializer_range", 0.02)
+        new_embeddings = tf.keras.layers.Embedding(
+            input_dim=new_num_tokens,
+            output_dim=old_embeddings.output_dim,
+            embeddings_initializer=get_initializer(init_range),
+            name=old_embeddings.embeddings.name[:-13],  # exact same scoped name except "/embeddings:0"
+        )
+        new_embeddings(tf.constant([[0]]))
+
+        # Copy the old embeddings to the new embeddings
+        if old_embeddings.input_dim >= new_num_tokens:
+            init_embeddings = old_embeddings.embeddings[:new_num_tokens]
+        else:
+            init_embeddings = tf.concat(
+                [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
+            )
+        new_embeddings.embeddings.assign(init_embeddings)
+        return new_embeddings
+
     def prune_heads(self, heads_to_prune):
         """
         Prunes heads of the base model.
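The copy strategy in `_v2_get_resized_embeddings` can be exercised on its own. A minimal standalone sketch (the helper name is illustrative, and the layer is assumed to be built by a dummy forward pass):

import tensorflow as tf

# Standalone sketch of the row-copy strategy above: build a freshly initialized
# Embedding of the target size and overwrite its leading rows with the old weights.
def resize_keras_embedding(old: tf.keras.layers.Embedding, new_num_tokens: int) -> tf.keras.layers.Embedding:
    new = tf.keras.layers.Embedding(input_dim=new_num_tokens, output_dim=old.output_dim)
    new(tf.constant([[0]]))  # one forward pass builds the layer so .embeddings exists
    if old.input_dim >= new_num_tokens:
        init = old.embeddings[:new_num_tokens]  # shrinking: truncate rows
    else:
        # growing: keep the old rows, append the freshly initialized ones
        init = tf.concat([old.embeddings, new.embeddings[old.input_dim:]], axis=0)
    new.embeddings.assign(init)
    return new

old = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
old(tf.constant([[0]]))  # build the old layer too
print(resize_keras_embedding(old, 12).embeddings.shape)  # (12, 4)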
@@ -2632,6 +2730,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         kwargs:
             Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
     """
+    # TODO (joao): flagged for deletion due to embeddings refactor

     def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
         super().__init__(**kwargs)
@@ -2848,6 +2947,8 @@ class TFWrappedEmbeddings:
     saving/storing the correct weights
     """

+    # TODO (joao): flagged for deletion due to embeddings refactor
+
     def __init__(self, layer, abs_scope_name=None):
         self._layer = layer
         self._abs_scope_name = abs_scope_name
@@ -35,8 +35,6 @@ from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFModelInputType,
     TFPreTrainedModel,
-    TFSharedEmbeddings,
-    TFWrappedEmbeddings,
     keras_serializable,
     unpack_inputs,
 )
@@ -113,7 +111,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     return (one_cst - expanded_mask) * LARGE_NEGATIVE


-class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
+class TFBartLearnedPositionalEmbedding(tf.keras.layers.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
     """
@@ -136,7 +134,8 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
             position_ids = tf.range(seq_len, delta=1, name="range")
         position_ids += past_key_values_length

-        return super().call(position_ids + self.offset)
+        offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
+        return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))


 class TFBartAttention(tf.keras.layers.Layer):
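The dtype guard matters because TensorFlow does not implicitly cast when adding two tensors of different integer dtypes. A small standalone sketch of the pattern (the offset value 2 mirrors Bart's learned-position offset; the helper name is illustrative):

import tensorflow as tf

# Build the offset constant with the same dtype as the incoming position_ids so
# the addition never mixes integer dtypes.
def add_offset(position_ids, offset=2):
    offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
    return position_ids + tf.constant(offset, dtype=offset_dtype)

print(add_offset(tf.range(5, dtype=tf.int64)))  # stays int64
print(add_offset(tf.range(5)))                  # default int32 path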
@@ -667,7 +666,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
         config: BartConfig
     """

-    def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
+    def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.dropout = tf.keras.layers.Dropout(config.dropout)
@@ -685,12 +684,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
         self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

-    def get_embed_tokens(self):
-        return self.embed_tokens
-
-    def set_embed_tokens(self, embed_tokens):
-        self.embed_tokens = embed_tokens
-
     @unpack_inputs
     def call(
         self,
@@ -750,7 +743,8 @@ class TFBartEncoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            with tf.name_scope(self.embed_tokens.name + "/"):
+                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
         hidden_states = inputs_embeds + embed_pos
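The `tf.name_scope` wrapper does not change the computation; it only affects the TF name scope under which the shared embedding is first built, presumably so its weight keeps a checkpoint-compatible scoped name. A toy version of the lookup-and-scale step (values are illustrative; the sqrt(d_model) factor mirrors Bart's `scale_embedding` behaviour):

import tensorflow as tf

# Toy embedding lookup + scaling as performed in the hunk above.
d_model = 16
embed_tokens = tf.keras.layers.Embedding(100, d_model, name="model.shared")
embed_scale = tf.math.sqrt(float(d_model))  # Bart scales by sqrt(d_model) when config.scale_embedding is True

input_ids = tf.constant([[5, 7, 42]])
with tf.name_scope(embed_tokens.name + "/"):
    inputs_embeds = embed_tokens(input_ids) * embed_scale
print(inputs_embeds.shape)  # (1, 3, 16)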
@@ -820,7 +814,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
         embed_tokens: output embedding
     """

-    def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
+    def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -837,12 +831,6 @@ class TFBartDecoder(tf.keras.layers.Layer):

         self.dropout = tf.keras.layers.Dropout(config.dropout)

-    def get_embed_tokens(self):
-        return self.embed_tokens
-
-    def set_embed_tokens(self, embed_tokens):
-        self.embed_tokens = embed_tokens
-
     @unpack_inputs
     def call(
         self,
@@ -943,7 +931,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
         positions = self.embed_positions(input_shape, position_ids=position_ids)

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            with tf.name_scope(self.embed_tokens.name + "/"):
+                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         hidden_states = inputs_embeds

@@ -1038,36 +1027,19 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
-
-        # set tf scope correctly
-        if load_weight_prefix is None:
-            load_weight_prefix = "model.shared"
-
-        with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
-            pass
-
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        embed_tokens.vocab_size = self.shared.vocab_size
-        embed_tokens.hidden_size = self.shared.hidden_size
-
-        self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
-        self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
+        load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
+        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
+
+        self.encoder = TFBartEncoder(config, self.shared, name="encoder")
+        self.decoder = TFBartDecoder(config, self.shared, name="decoder")

     def get_input_embeddings(self):
         return self.shared

     def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.encoder.set_embed_tokens(embed_tokens)
-        self.decoder.set_embed_tokens(embed_tokens)
+        self.shared = new_embeddings
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared

     @unpack_inputs
     def call(
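With a plain `tf.keras.layers.Embedding` shared by reference, swapping the input embeddings becomes a simple attribute assignment. A toy sketch of the sharing scheme (class and attribute names are illustrative, not the real TFBart layers):

import tensorflow as tf

# Encoder and decoder hold a reference to the same Embedding, so replacing the
# input embeddings is just reassigning that reference everywhere.
class ToySubLayer(tf.keras.layers.Layer):
    def __init__(self, embed_tokens, **kwargs):
        super().__init__(**kwargs)
        self.embed_tokens = embed_tokens

class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model, **kwargs):
        super().__init__(**kwargs)
        self.shared = tf.keras.layers.Embedding(vocab_size, d_model, name="model.shared")
        self.encoder = ToySubLayer(self.shared, name="encoder")
        self.decoder = ToySubLayer(self.shared, name="decoder")

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

layer = ToyMainLayer(vocab_size=50265, d_model=16)
layer.set_input_embeddings(tf.keras.layers.Embedding(50270, 16, name="model.shared"))
assert layer.encoder.embed_tokens is layer.decoder.embed_tokens is layer.shared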
@@ -1273,11 +1245,7 @@ class BiasLayer(tf.keras.layers.Layer):
     BART_START_DOCSTRING,
 )
 class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
-    _keys_to_ignore_on_load_unexpected = [
-        r"model.encoder.embed_tokens.weight",
-        r"model.decoder.embed_tokens.weight",
-    ]
-
     _keys_to_ignore_on_load_missing = [r"final_logits_bias"]
     _requires_load_weight_prefix = True

     def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
@@ -1303,10 +1271,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
         self.set_input_embeddings(value)

     def get_bias(self):
-        return {"final_logits_bias": self.final_logits_bias}
+        return {"final_logits_bias": self.bias_layer.bias}

     def set_bias(self, value):
-        self.final_logits_bias = value["final_logits_bias"]
+        self.bias_layer.bias = value["final_logits_bias"]

     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1374,7 +1342,9 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
             return_dict=return_dict,
             training=training,
         )
-        lm_logits = self.model.shared(outputs[0], mode="linear")
+        # TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
+        # The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
+        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
         lm_logits = self.bias_layer(lm_logits)
         masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

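The head now computes logits by projecting the decoder output onto the vocabulary with the shared embedding matrix and then adding the final-logits bias. A shape-level sketch with random tensors standing in for the model's real weights:

import tensorflow as tf

# Shape-level sketch of the LM head computation in the hunk above.
batch, seq_len, d_model, vocab_size = 2, 5, 16, 100
hidden_states = tf.random.normal((batch, seq_len, d_model))   # decoder output
embedding_matrix = tf.random.normal((vocab_size, d_model))    # stand-in for the shared embedding weights
final_logits_bias = tf.zeros((1, vocab_size))                 # same shape as Bart's final_logits_bias

lm_logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + final_logits_bias
print(lm_logits.shape)  # (2, 5, 100)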
@@ -137,7 +137,8 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
             position_ids = tf.range(seq_len, delta=1, name="range")
         position_ids += past_key_values_length

-        return super().call(position_ids + self.offset)
+        offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
+        return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))


 # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
@@ -230,69 +230,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
         name = model.get_bias()
         assert name is None

-    def test_resize_token_embeddings(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        def _get_word_embedding_weight(model, embedding_layer):
-            if hasattr(embedding_layer, "weight"):
-                return embedding_layer.weight
-            else:
-                # Here we build the word embeddings weights if not exists.
-                # And then we retry to get the attribute once built.
-                model(model.dummy_inputs)
-                if hasattr(embedding_layer, "weight"):
-                    return embedding_layer.weight
-                else:
-                    return None
-
-        for model_class in self.all_model_classes:
-            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
-                # build the embeddings
-                model = model_class(config=config)
-                old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
-                old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
-                old_final_logits_bias = model.get_bias()
-
-                # reshape the embeddings
-                model.resize_token_embeddings(size)
-                new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
-                new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
-                new_final_logits_bias = model.get_bias()
-
-                # check that the resized embeddings size matches the desired size.
-                assert_size = size if size is not None else config.vocab_size
-
-                self.assertEqual(new_input_embeddings.shape[0], assert_size)
-
-                # check that weights remain the same after resizing
-                models_equal = True
-                for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
-                    if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                        models_equal = False
-                self.assertTrue(models_equal)
-
-                if old_output_embeddings is not None and new_output_embeddings is not None:
-                    self.assertEqual(new_output_embeddings.shape[0], assert_size)
-
-                    models_equal = True
-                    for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
-                        if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                            models_equal = False
-                    self.assertTrue(models_equal)
-
-                if old_final_logits_bias is not None and new_final_logits_bias is not None:
-                    old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
-                    new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
-                    self.assertEqual(new_final_logits_bias.shape[0], 1)
-                    self.assertEqual(new_final_logits_bias.shape[1], assert_size)
-
-                    models_equal = True
-                    for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
-                        for p1, p2 in zip(old, new):
-                            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                                models_equal = False
-                    self.assertTrue(models_equal)
-
     @tooslow
     def test_saved_model_creation(self):
         pass
@@ -635,7 +572,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):

     def test_xsum_1_1_generation(self):
         model = self.xsum_1_1_model
-        assert model.model.decoder.embed_tokens._layer == model.model.shared
+        assert model.model.decoder.embed_tokens == model.model.shared
         ARTICLE = (
             "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
             " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
@@ -685,7 +622,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
     def test_xsum_1_1_xla_generation(self):
         # same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
         model = self.xsum_1_1_model
-        assert model.model.decoder.embed_tokens._layer == model.model.shared
+        assert model.model.decoder.embed_tokens == model.model.shared
         ARTICLE = (
             "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
             " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
@@ -1144,30 +1144,20 @@ class TFModelTesterMixin:
         self.assert_outputs_same(output_for_dict_input, output_for_kw_input)

     def test_resize_token_embeddings(self):
+        # TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
+        # tf.keras.layers.Embedding
+
         if not self.test_resize_embeddings:
             return
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         def _get_word_embedding_weight(model, embedding_layer):
-            embeds = getattr(embedding_layer, "weight", None)
-            if embeds is not None:
-                return embeds
-
-            embeds = getattr(embedding_layer, "decoder", None)
-            if embeds is not None:
-                return embeds
-
-            model(model.dummy_inputs)
-
-            embeds = getattr(embedding_layer, "weight", None)
-            if embeds is not None:
-                return embeds
-
-            embeds = getattr(embedding_layer, "decoder", None)
-            if embeds is not None:
-                return embeds
-
-            return None
+            if isinstance(embedding_layer, tf.keras.layers.Embedding):
+                # builds the embeddings layer
+                model(model.dummy_inputs)
+                return embedding_layer.embeddings
+            else:
+                return model._get_word_embedding_weight(embedding_layer)

         for model_class in self.all_model_classes:
             for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
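The updated helper relies on the fact that a built `tf.keras.layers.Embedding` exposes its weight matrix as `.embeddings`. A hedged end-to-end sketch using a tiny random config (`_get_word_embedding_weight` is private API, shown only for illustration; the config values are arbitrary):

import tensorflow as tf
from transformers import BartConfig, TFBartForConditionalGeneration

# Mirror of the helper's dispatch logic, outside the test mixin.
config = BartConfig(vocab_size=64, d_model=16, encoder_layers=1, decoder_layers=1,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=32, decoder_ffn_dim=32)
model = TFBartForConditionalGeneration(config)
model(model.dummy_inputs)  # a forward pass builds the model so the embedding weights exist

embedding_layer = model.get_input_embeddings()
if isinstance(embedding_layer, tf.keras.layers.Embedding):
    weight_matrix = embedding_layer.embeddings
else:
    weight_matrix = model._get_word_embedding_weight(embedding_layer)
print(weight_matrix.shape)  # (vocab_size, d_model)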
@@ -1195,10 +1185,10 @@ class TFModelTesterMixin:

                 if old_bias is not None and new_bias is not None:
                     for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
-                        self.assertEqual(new_weight.shape[0], assert_size)
+                        self.assertEqual(new_weight.shape[-1], assert_size)

                         models_equal = True
-                        for p1, p2 in zip(old_weight.value(), new_weight.value()):
+                        for p1, p2 in zip(tf.squeeze(old_weight), tf.squeeze(new_weight)):
                             if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
                                 models_equal = False
                         self.assertTrue(models_equal)