Mirror of https://github.com/huggingface/transformers.git
🚨🚨🚨 TF: Remove TFWrappedEmbeddings (breaking: TF embedding initialization updated for encoder-decoder models) (#19263)
* added test
* correct embedding init
* some changes in blenderbot (incomplete)
* update blenderbot (diff to be used as reference)
* update blenderbot_small
* update LED
* update marian
* update T5 and remove TFWrappedEmbeddings
* nullcontext() -> ContextManagers()
* fix embedding init
This commit is contained in:
parent 8e4ee28e34
commit 462cd641d9
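Before the per-model hunks, here is a minimal standalone sketch of the pattern this commit applies everywhere: the shared embedding becomes a plain tf.keras.layers.Embedding with an explicit TruncatedNormal(stddev=config.init_std) initializer, a `load_weight_prefix` attribute records the name scope it must be built under, and the lookup runs inside ContextManagers instead of a TFWrappedEmbeddings wrapper. The toy sizes and the standalone-script form are assumptions for illustration; ContextManagers is the helper the diff imports from `...utils`.

import tensorflow as tf
from transformers.utils import ContextManagers  # same helper the diff imports from ...utils

vocab_size, d_model, init_std = 128, 16, 0.02  # toy values, not taken from any real config

# New style: a plain Keras Embedding with an explicit truncated-normal initializer ...
shared = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=d_model,
    embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_std),
    name="model.shared",
)
# ... plus an attribute telling encoder/decoder which name scope to build the weight under.
shared.load_weight_prefix = "model.shared"

input_ids = tf.constant([[1, 2, 3]])
context = []
if hasattr(shared, "load_weight_prefix"):
    # A trailing "/" makes tf.name_scope treat the name as an absolute scope.
    context.append(tf.name_scope(shared.load_weight_prefix + "/"))
with ContextManagers(context):  # an empty list behaves like contextlib.nullcontext()
    inputs_embeds = shared(input_ids)
print(inputs_embeds.shape)  # (1, 3, 16)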
src/transformers/modeling_tf_utils.py
@@ -3038,36 +3038,3 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate
        `tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


class TFWrappedEmbeddings:
    """
    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with
    weight restoring. Also it makes sure that the layer is called from the correct scope to avoid problem with
    saving/storing the correct weights
    """

    # TODO (joao): flagged for delection due to embeddings refactor

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name
        self.vocab_size = self._layer.vocab_size

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)

src/transformers/models/bart/modeling_tf_bart.py
@@ -16,7 +16,6 @@


import random
from contextlib import nullcontext
from typing import Optional, Tuple, Union

import numpy as np
@@ -41,6 +40,7 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
    ContextManagers,
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
@@ -741,11 +741,10 @@ class TFBartEncoder(tf.keras.layers.Layer):
            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
            # is used with a name ending in `/`, that name replaces the current name scope.
            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
            context = []
            if hasattr(self.embed_tokens, "load_weight_prefix"):
                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
            else:
                context_manager = nullcontext()
            with context_manager:
                context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
            with ContextManagers(context):
                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
                tf.debugging.assert_less(
@@ -945,11 +944,10 @@ class TFBartDecoder(tf.keras.layers.Layer):
            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
            # is used with a name ending in `/`, that name replaces the current name scope.
            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
            context = []
            if hasattr(self.embed_tokens, "load_weight_prefix"):
                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
            else:
                context_manager = nullcontext()
            with context_manager:
                context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
            with ContextManagers(context):
                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
                tf.debugging.assert_less(
@@ -1378,8 +1376,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
            return_dict=return_dict,
            training=training,
        )
        # TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
        # The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
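The `lm_logits` line above replaces the old `shared(..., mode="linear")` projection with an explicit matmul against the embedding matrix. A hedged, standalone illustration of that projection with toy shapes (no real model involved):

import tensorflow as tf

shared = tf.keras.layers.Embedding(input_dim=128, output_dim=16)  # stands in for model.shared
_ = shared(tf.zeros((1, 1), dtype=tf.int32))                      # dummy call so the weight exists
hidden_states = tf.random.normal((2, 5, 16))                      # decoder output: (batch, seq_len, d_model)

# `shared.weights` is a one-element list holding the (vocab, d_model) matrix; tf.matmul
# converts it to a (1, vocab, d_model) tensor and broadcasts it over the batch dimension,
# matching `tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)` in the diff.
lm_logits = tf.matmul(hidden_states, shared.weights, transpose_b=True)
print(lm_logits.shape)  # (2, 5, 128)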

src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -35,13 +35,12 @@ from ...modeling_tf_utils import (
    DUMMY_INPUTS,
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    TFSharedEmbeddings,
    TFWrappedEmbeddings,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
    ContextManagers,
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
@@ -119,7 +118,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFBlenderbotLearnedPositionalEmbedding(tf.keras.layers.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """
@@ -133,8 +132,10 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
        """Input is expected to be of size [bsz x seqlen]."""
        if position_ids is None:
            seq_len = input_shape[1]
            position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return super().call(position_ids)
            position_ids = tf.range(seq_len, delta=1, name="range")
            position_ids += past_key_values_length

        return super().call(tf.cast(position_ids, dtype=tf.int32))


# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot
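The learned positional embedding now subclasses tf.keras.layers.Embedding directly and builds its own position ids from past_key_values_length. A hedged toy version of that call logic (the class name, sizes, and inputs below are made up for illustration):

import tensorflow as tf

class ToyLearnedPositionalEmbedding(tf.keras.layers.Embedding):
    """Illustrative stand-in for the TF*LearnedPositionalEmbedding classes in this diff."""

    def call(self, input_shape, past_key_values_length=0, position_ids=None):
        """Input is expected to be of size [bsz x seqlen]."""
        if position_ids is None:
            seq_len = input_shape[1]
            position_ids = tf.range(seq_len, delta=1, name="range")
            position_ids += past_key_values_length

        return super().call(tf.cast(position_ids, dtype=tf.int32))

embed_positions = ToyLearnedPositionalEmbedding(input_dim=32, output_dim=16)
# Positions 2..5 for a length-4 sequence that follows 2 cached key/value steps.
print(embed_positions(tf.constant([1, 4]), past_key_values_length=2).shape)  # (4, 16)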
@@ -638,7 +639,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
    config: BlenderbotConfig
    """

    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dropout = tf.keras.layers.Dropout(config.dropout)
@@ -726,17 +727,25 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
            # is used with a name ending in `/`, that name replaces the current name scope.
            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
            context = []
            if hasattr(self.embed_tokens, "load_weight_prefix"):
                context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
            with ContextManagers(context):
                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
                tf.debugging.assert_less(
                    input_ids,
                    tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
                    message=(
                        "input_ids must be smaller than the embedding layer's input dimension (got"
                        f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
                    ),
                )
                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)
        hidden_states = inputs_embeds + embed_pos
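The bounds check that now runs inside the name-scope context is worth seeing in isolation, since tf.gather on GPU silently returns zeros for out-of-range indices instead of failing. A hedged standalone sketch with a toy vocabulary:

import tensorflow as tf

embed_tokens = tf.keras.layers.Embedding(input_dim=10, output_dim=4)  # toy vocab of 10 tokens
input_ids = tf.constant([[3, 7, 12]])  # 12 is out of range

try:
    # Same guard the diff places before the embedding lookup, using the layer's input_dim.
    tf.debugging.assert_less(
        input_ids,
        tf.cast(embed_tokens.input_dim, dtype=input_ids.dtype),
        message="input_ids must be smaller than the embedding layer's input dimension",
    )
except tf.errors.InvalidArgumentError as err:
    print("caught out-of-range input_ids:", type(err).__name__)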
@@ -805,7 +814,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
    embed_tokens: output embedding
    """

    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.padding_idx = config.pad_token_id
@@ -933,17 +942,21 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
        positions = self.embed_positions(input_shape, position_ids=position_ids)

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
            context = []
            if hasattr(self.embed_tokens, "load_weight_prefix"):
                context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
            with ContextManagers(context):
                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
                tf.debugging.assert_less(
                    input_ids,
                    tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
                    message=(
                        "input_ids must be smaller than the embedding layer's input dimension (got"
                        f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
                    ),
                )
                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        hidden_states = inputs_embeds

@@ -1037,32 +1050,25 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.config = config
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
        self.shared = tf.keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.d_model,
            embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
            name="model.shared",
        )
        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
        self.shared.load_weight_prefix = "model.shared"

        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
            pass

        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
        embed_tokens.vocab_size = self.shared.vocab_size
        embed_tokens.hidden_size = self.shared.hidden_size

        self.encoder = TFBlenderbotEncoder(config, embed_tokens, name="encoder")
        self.decoder = TFBlenderbotDecoder(config, embed_tokens, name="decoder")
        self.encoder = TFBlenderbotEncoder(config, self.shared, name="encoder")
        self.decoder = TFBlenderbotDecoder(config, self.shared, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        self.shared.vocab_size = self.shared.weight.shape[0]
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
            pass
        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)
        self.shared = new_embeddings
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    @unpack_inputs
    def call(
@@ -1284,7 +1290,6 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )
        self.final_logits_bias = self.bias_layer.bias  # alias to keep the same interface with PT

    def get_decoder(self):
        return self.model.decoder
@@ -1299,10 +1304,15 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
        self.set_input_embeddings(value)

    def get_bias(self):
        return {"final_logits_bias": self.final_logits_bias}
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        self.final_logits_bias = value["final_logits_bias"]
        # Replaces the existing layers containing bias for correct (de)serialization.
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
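get_bias/set_bias above rebuild the BiasLayer instead of mutating a Python-level alias, so the bias stays a registered weight and survives (de)serialization. BiasLayer itself is defined elsewhere in these modeling files; the stand-in below is a hedged approximation, just enough to show the rebuild-and-assign pattern.

import tensorflow as tf

class BiasLayer(tf.keras.layers.Layer):
    """Approximation of the BiasLayer used in these TF modeling files: one registered bias vector."""

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        return x + self.bias

bias_layer = BiasLayer(name="final_logits_bias", shape=[1, 4], initializer="zeros", trainable=False)
new_bias = tf.constant([[0.0, 1.0, 2.0, 3.0]])

# set_bias-style update: rebuild the layer for the right vocab size, then copy the values in.
vocab_size = new_bias.shape[-1]
bias_layer = BiasLayer(name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False)
bias_layer.bias.assign(new_bias)
print(bias_layer(tf.zeros((1, 4))).numpy())  # [[0. 1. 2. 3.]]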
@@ -1385,7 +1395,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
            return_dict=return_dict,
            training=training,
        )
        lm_logits = self.model.shared(outputs[0], mode="linear")
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

@ -34,13 +34,12 @@ from ...modeling_tf_utils import (
|
||||
DUMMY_INPUTS,
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ContextManagers,
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -119,7 +118,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
|
||||
|
||||
# Copied from transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
|
||||
class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
class TFBlenderbotSmallLearnedPositionalEmbedding(tf.keras.layers.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
@ -133,8 +132,10 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
if position_ids is None:
|
||||
seq_len = input_shape[1]
|
||||
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
|
||||
return super().call(position_ids)
|
||||
position_ids = tf.range(seq_len, delta=1, name="range")
|
||||
position_ids += past_key_values_length
|
||||
|
||||
return super().call(tf.cast(position_ids, dtype=tf.int32))
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->BlenderbotSmall
|
||||
@ -643,7 +644,9 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
config: BlenderbotSmallConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(
|
||||
self, config: BlenderbotSmallConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
@ -731,17 +734,25 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
@ -809,7 +820,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
embed_tokens: output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(
|
||||
self, config: BlenderbotSmallConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -931,17 +944,25 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
if input_shape[-1] > 1:
|
||||
@ -1038,32 +1059,25 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
self.shared = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.d_model,
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
|
||||
name="model.shared",
|
||||
)
|
||||
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
|
||||
self.shared.load_weight_prefix = "model.shared"
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
embed_tokens.vocab_size = self.shared.vocab_size
|
||||
embed_tokens.hidden_size = self.shared.hidden_size
|
||||
|
||||
self.encoder = TFBlenderbotSmallEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFBlenderbotSmallDecoder(config, embed_tokens, name="decoder")
|
||||
self.encoder = TFBlenderbotSmallEncoder(config, self.shared, name="encoder")
|
||||
self.decoder = TFBlenderbotSmallDecoder(config, self.shared, name="decoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
self.shared = new_embeddings
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
@ -1271,7 +1285,6 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
@ -1286,10 +1299,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
return {"final_logits_bias": self.bias_layer.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||
vocab_size = value["final_logits_bias"].shape[-1]
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.bias_layer.bias.assign(value["final_logits_bias"])
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING)
|
||||
@ -1357,7 +1375,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
lm_logits = self.model.shared(outputs[0], mode="linear")
|
||||
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
|
||||
lm_logits = self.bias_layer(lm_logits)
|
||||
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
|
@ -29,14 +29,13 @@ from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
|
||||
from ...modeling_tf_utils import (
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ContextManagers,
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -114,7 +113,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
return (one_cst - expanded_mask) * LARGE_NEGATIVE
|
||||
|
||||
|
||||
class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
class TFLEDLearnedPositionalEmbedding(tf.keras.layers.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
@ -124,10 +123,11 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
|
||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
bsz, seq_len = input_shape[:2]
|
||||
seq_len = input_shape[1]
|
||||
position_ids = tf.range(seq_len, delta=1, name="range")
|
||||
position_ids += past_key_values_length
|
||||
|
||||
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
|
||||
return super().call(positions)
|
||||
return super().call(tf.cast(position_ids, dtype=tf.int32))
|
||||
|
||||
|
||||
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerSelfAttention with TFLongformer->TFLEDEncoder
|
||||
@ -1650,7 +1650,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
config: LEDConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: LEDConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
@ -1739,17 +1739,25 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
@ -1917,7 +1925,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
embed_tokens: output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: LEDConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -2024,17 +2032,25 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -2134,32 +2150,25 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config: LEDConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="led.shared")
|
||||
self.shared = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.d_model,
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
|
||||
name="led.shared",
|
||||
)
|
||||
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
|
||||
self.shared.load_weight_prefix = "led.shared"
|
||||
|
||||
with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
embed_tokens.vocab_size = self.shared.vocab_size
|
||||
embed_tokens.hidden_size = self.shared.hidden_size
|
||||
|
||||
self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")
|
||||
self.encoder = TFLEDEncoder(config, self.shared, name="encoder")
|
||||
self.decoder = TFLEDDecoder(config, self.shared, name="decoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
self.shared = new_embeddings
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
@ -2365,7 +2374,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT
|
||||
|
||||
# TODO (Joao): investigate why LED has numerical issues in XLA generate
|
||||
self.supports_xla_generation = False
|
||||
|
||||
@ -2376,10 +2385,15 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
return self.led.encoder
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
return {"final_logits_bias": self.bias_layer.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||
vocab_size = value["final_logits_bias"].shape[-1]
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.bias_layer.bias.assign(value["final_logits_bias"])
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
@ -2454,7 +2468,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
lm_logits = self.led.shared(outputs[0], mode="linear")
|
||||
lm_logits = tf.matmul(outputs[0], self.led.shared.weights, transpose_b=True)
|
||||
lm_logits = self.bias_layer(lm_logits)
|
||||
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
|
@ -34,13 +34,12 @@ from ...modeling_tf_utils import (
|
||||
DUMMY_INPUTS,
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ContextManagers,
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -684,7 +683,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
config: MarianConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: MarianConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: MarianConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
@ -772,17 +771,25 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
@ -849,7 +856,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
embed_tokens: output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: MarianConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: MarianConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -977,17 +984,25 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
positions = self.embed_positions(input_shape, position_ids=position_ids)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -1079,32 +1094,25 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
self.shared = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.d_model,
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
|
||||
name="model.shared",
|
||||
)
|
||||
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
|
||||
self.shared.load_weight_prefix = "model.shared"
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
embed_tokens.vocab_size = self.shared.vocab_size
|
||||
embed_tokens.hidden_size = self.shared.hidden_size
|
||||
|
||||
self.encoder = TFMarianEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFMarianDecoder(config, embed_tokens, name="decoder")
|
||||
self.encoder = TFMarianEncoder(config, self.shared, name="encoder")
|
||||
self.decoder = TFMarianDecoder(config, self.shared, name="decoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
self.shared = new_embeddings
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
@ -1314,7 +1322,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
@ -1329,10 +1336,15 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
return {"final_logits_bias": self.bias_layer.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||
vocab_size = value["final_logits_bias"].shape[-1]
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.bias_layer.bias.assign(value["final_logits_bias"])
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
|
||||
@ -1400,7 +1412,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
lm_logits = self.model.shared(outputs[0], mode="linear")
|
||||
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
|
||||
lm_logits = self.bias_layer(lm_logits)
|
||||
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
|
@ -34,13 +34,12 @@ from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import (
|
||||
ContextManagers,
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -118,7 +117,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartLearnedPositionalEmbedding with Bart->MBart
|
||||
class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
class TFMBartLearnedPositionalEmbedding(tf.keras.layers.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
@ -667,7 +666,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
config: MBartConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: MBartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: MBartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
@ -757,17 +756,25 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
@ -837,7 +844,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
embed_tokens: output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: MBartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
|
||||
def __init__(self, config: MBartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -969,17 +976,25 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
positions = self.embed_positions(input_shape, position_ids=position_ids)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
|
||||
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
|
||||
tf.debugging.assert_less(
|
||||
input_ids,
|
||||
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
|
||||
message=(
|
||||
"input_ids must be smaller than the embedding layer's input dimension (got"
|
||||
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
|
||||
),
|
||||
)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -1074,32 +1089,25 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
self.shared = tf.keras.layers.Embedding(
|
||||
input_dim=config.vocab_size,
|
||||
output_dim=config.d_model,
|
||||
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
|
||||
name="model.shared",
|
||||
)
|
||||
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
|
||||
self.shared.load_weight_prefix = "model.shared"
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
embed_tokens.vocab_size = self.shared.vocab_size
|
||||
embed_tokens.hidden_size = self.shared.hidden_size
|
||||
|
||||
self.encoder = TFMBartEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFMBartDecoder(config, embed_tokens, name="decoder")
|
||||
self.encoder = TFMBartEncoder(config, self.shared, name="encoder")
|
||||
self.decoder = TFMBartDecoder(config, self.shared, name="decoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
self.shared = new_embeddings
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
@ -1313,7 +1321,6 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
@ -1328,10 +1335,15 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
return {"final_logits_bias": self.bias_layer.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||
vocab_size = value["final_logits_bias"].shape[-1]
|
||||
self.bias_layer = BiasLayer(
|
||||
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
self.bias_layer.bias.assign(value["final_logits_bias"])
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||
@ -1397,7 +1409,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
lm_logits = self.model.shared(outputs[0], mode="linear")
|
||||
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
|
||||
lm_logits = self.bias_layer(lm_logits)
|
||||
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
|
@ -34,13 +34,12 @@ from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -686,7 +685,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
config: PegasusConfig
"""

def __init__(self, config: PegasusConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: PegasusConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = tf.keras.layers.Dropout(config.dropout)
@ -775,17 +774,25 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
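Two details carry this hunk: `tf.name_scope` treats a name that ends in `/` as an absolute scope (per the comment above, it replaces the current scope rather than nesting under it), and the list-based `ContextManagers(...)` wrapper is a no-op when the list is empty, which is what replaces the old `nullcontext()` fallback. A minimal standalone sketch of the same pattern; `SimpleContextManagers` is an assumed stand-in for `transformers.utils.ContextManagers`, and the printed weight name may vary with TensorFlow version and build context:

    from contextlib import ExitStack

    import tensorflow as tf

    class SimpleContextManagers:
        # Assumed behaviour: enter every context manager in the list, exit them all afterwards.
        def __init__(self, context_managers):
            self.context_managers = context_managers
            self.stack = ExitStack()

        def __enter__(self):
            for cm in self.context_managers:
                self.stack.enter_context(cm)

        def __exit__(self, *exc):
            self.stack.close()

    embed_tokens = tf.keras.layers.Embedding(32, 8, name="shared")
    embed_tokens.load_weight_prefix = "model.shared"  # same attribute the modeling code checks for

    context = []
    if hasattr(embed_tokens, "load_weight_prefix"):
        context.append(tf.name_scope(embed_tokens.load_weight_prefix + "/"))
    with SimpleContextManagers(context):
        _ = embed_tokens(tf.constant([[1, 2, 3]]))  # first call builds the weight inside the prefixed scope

    # Intended weight name (see the comment in the hunk): "model.shared/shared/embeddings:0"
    print(embed_tokens.weights[0].name)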
@ -854,7 +861,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
embed_tokens: output embedding
"""

def __init__(self, config: PegasusConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: PegasusConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
@ -983,17 +990,25 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

hidden_states = inputs_embeds

@ -1087,32 +1102,25 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)

self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared",
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared"

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size

self.encoder = TFPegasusEncoder(config, embed_tokens, name="encoder")
self.decoder = TFPegasusDecoder(config, embed_tokens, name="decoder")
self.encoder = TFPegasusEncoder(config, self.shared, name="encoder")
self.decoder = TFPegasusDecoder(config, self.shared, name="decoder")

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

@unpack_inputs
def call(
@ -1323,7 +1331,6 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT

def get_decoder(self):
return self.model.decoder
@ -1338,10 +1345,15 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
self.set_input_embeddings(value)

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
return {"final_logits_bias": self.bias_layer.bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size = value["final_logits_bias"].shape[-1]
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
)
self.bias_layer.bias.assign(value["final_logits_bias"])

@unpack_inputs
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
@ -1409,7 +1421,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
return_dict=return_dict,
training=training,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
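With `TFSharedEmbeddings` gone, the `mode="linear"` projection is replaced by an explicit matrix product against the shared embedding weights, followed by the bias layer. A minimal sketch of that weight-tying step with made-up shapes (the flatten/reshape only keeps the matmul unambiguous; the modeling code multiplies the 3-D hidden states directly):

    import tensorflow as tf

    vocab_size, d_model = 32, 8
    shared = tf.keras.layers.Embedding(vocab_size, d_model, name="shared")

    decoder_output = shared(tf.constant([[3, 5, 7]]))   # [batch, seq, d_model]; also builds the weight
    embedding_matrix = shared.embeddings                # [vocab_size, d_model]

    flat = tf.reshape(decoder_output, [-1, d_model])
    lm_logits = tf.matmul(flat, embedding_matrix, transpose_b=True)   # tied output projection
    lm_logits = tf.reshape(lm_logits, [1, -1, vocab_size])            # [batch, seq, vocab_size]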
@ -36,8 +36,7 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
get_initializer,
keras_serializable,
unpack_inputs,
)
@ -45,6 +44,7 @@ from ...tf_utils import shape_list, stable_softmax
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
ContextManagers,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@ -681,17 +681,25 @@ class TFT5MainLayer(tf.keras.layers.Layer):

if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids)
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = input_shape

@ -898,21 +906,10 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
return self.shared

def set_input_embeddings(self, value):
try:
self.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
self.shared.weight = value

self.shared.vocab_size = shape_list(value)[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.embed_tokens = embed_tokens
self.shared = value
self.encoder.embed_tokens = self.shared
if hasattr(self, "decoder"):
self.decoder.embed_tokens = embed_tokens
self.decoder.embed_tokens = self.shared

def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
@ -1133,24 +1130,24 @@ num_heads))`.
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(self.config.initializer_factor),
name="shared",
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "shared"

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder")

decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder")

def get_encoder(self):
return self.encoder
@ -1286,24 +1283,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
self.shared = tf.keras.layers.Embedding(
config.vocab_size,
config.d_model,
name="shared",
embeddings_initializer=get_initializer(self.config.initializer_factor),
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "shared"

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder")

decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
self.decoder = TFT5MainLayer(decoder_config, self.shared, name="decoder")

if not config.tie_word_embeddings:
lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
@ -1435,7 +1431,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# T5v1.1 does not tie output word embeddings and thus does not require downscaling
if self.config.tie_word_embeddings:
sequence_output = sequence_output * (self.model_dim**-0.5)
logits = self.shared(sequence_output, mode="linear")
logits = tf.matmul(sequence_output, self.shared.weights, transpose_b=True)
else:
logits = self.lm_head(sequence_output)

@ -1564,19 +1560,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
self.shared = tf.keras.layers.Embedding(
config.vocab_size,
config.d_model,
name="shared",
embeddings_initializer=get_initializer(self.config.initializer_factor),
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "shared"

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
self.encoder = TFT5MainLayer(encoder_config, self.shared, name="encoder")

@property
def dummy_inputs(self):
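One T5-specific detail in the hunk above: when `tie_word_embeddings` is set, the decoder output is rescaled by `d_model ** -0.5` before the tied projection (the same matmul pattern as in the earlier sketch), while configs with an untied head keep a separate `lm_head` and skip the downscaling. A small illustrative sketch, with all shapes invented:

    import tensorflow as tf

    d_model, vocab_size, tie_word_embeddings = 8, 32, True
    shared = tf.keras.layers.Embedding(vocab_size, d_model, name="shared")
    sequence_output = shared(tf.constant([[1, 2, 3]]))    # stand-in for the decoder output

    if tie_word_embeddings:
        sequence_output = sequence_output * (d_model**-0.5)   # downscale before reusing the embedding matrix
        flat = tf.reshape(sequence_output, [-1, d_model])
        logits = tf.reshape(tf.matmul(flat, shared.embeddings, transpose_b=True), [1, -1, vocab_size])
    else:
        lm_head = tf.keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")   # untied head (T5v1.1 style)
        logits = lm_head(sequence_output)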
@ -1675,13 +1675,11 @@ from ...modeling_tf_outputs import (
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from ...utils import ContextManagers, logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config


@ -1747,7 +1745,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedEmbeddings):
class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
@ -1757,12 +1755,10 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE

def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]

positions = tf.range(
past_key_values_length, seq_len + past_key_values_length, delta=1, name="range"
)
return super().call(positions)
seq_len = input_shape[1]
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
return super().call(tf.cast(position_ids, dtype=tf.int32))


class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
@ -2226,7 +2222,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
config: {{cookiecutter.camelcase_modelname}}Config
"""

def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = tf.keras.layers.Dropout(config.dropout)
@ -2315,17 +2311,25 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -2388,7 +2392,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
embed_tokens: output embedding
"""

def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
@ -2514,17 +2518,25 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length)

if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
),
)
inputs_embeds = self.embed_tokens(input_ids)
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
),
)
inputs_embeds = self.embed_tokens(input_ids)

hidden_states = inputs_embeds

@ -2637,32 +2649,25 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)

self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
self.shared = tf.keras.layers.Embedding(
input_dim=config.vocab_size,
output_dim=config.d_model,
embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
name="model.shared"
)
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared"

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size

self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder")
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder")
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, self.shared, name="encoder")
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, self.shared, name="decoder")

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

@unpack_inputs
def call(
@ -2866,7 +2871,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
self.final_logits_bias = self.bias_layer.bias # alias to keep the same interface with PT

def get_decoder(self):
return self.model.decoder
@ -2875,10 +2879,15 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
return self.model.encoder

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
return {"final_logits_bias": self.bias_layer.bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
# Replaces the existing layers containing bias for correct (de)serialization.
vocab_size = value["final_logits_bias"].shape[-1]
self.bias_layer = BiasLayer(
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
)
self.bias_layer.bias.assign(value["final_logits_bias"])

def get_output_embeddings(self):
return self.get_input_embeddings()
@ -2952,7 +2961,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
return_dict=return_dict,
training=training
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
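The per-model `test_resize_token_embeddings` copies removed in the following test hunks exercised the public resizing entry point; a short, hedged usage example of that path (mirroring the new slow T5 test further down, which uses `t5-small` — the added special-token strings here are illustrative):

    from transformers import T5Tokenizer, TFT5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

    tokenizer.add_special_tokens({"bos_token": "<bos>", "eos_token": "<eos>"})
    model.resize_token_embeddings(len(tokenizer))   # input (and tied output) embeddings now match the tokenizer

    assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)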
@ -889,69 +889,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@unittest.skip(reason="Template classes interact badly with this test.")
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
|
@ -217,87 +217,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_saved_model_creation(self):
|
||||
pass
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_tf
|
||||
|
@ -215,92 +215,11 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@tooslow
|
||||
def test_saved_model_creation(self):
|
||||
pass
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_tf
|
||||
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
|
||||
|
@ -228,69 +228,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["global_attention_mask"] = tf.zeros_like(inputs_dict["attention_mask"])
|
||||
@ -374,20 +311,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
@ -250,87 +250,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_saved_model_creation(self):
|
||||
pass
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
||||
@require_tf
|
||||
class AbstractMarianIntegrationTest(unittest.TestCase):
|
||||
|
@ -218,95 +218,11 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@tooslow
|
||||
def test_saved_model_creation(self):
|
||||
pass
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@require_tf
|
||||
|
@ -248,87 +248,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_saved_model_creation(self):
|
||||
pass
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if tf.debugging.assert_near(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
if len(prefix) > 0:
|
||||
prefix = f"{prefix}: "
|
||||
raise AssertionError(f"{prefix}{a} != {b}")
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
@ -318,20 +318,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: Fix head-masking according to PyTorch T5 model
pass

@slow
def test_resize_embeddings(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
original_vocab_size = model.get_input_embeddings().weight.shape[0]
# the vocab size is defined in the model config
self.assertEqual(original_vocab_size, model.config.vocab_size)

tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
model._resize_token_embeddings(len(tokenizer))
# the vocab size is now resized to the length of the tokenizer, which is different from the original size
self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size)

# This test is run in `TFT5EncoderOnlyModelTest`, where the main layer has the same inputs as the model
@unittest.skip(reason="The inputs of the Main Layer are different.")
def test_keras_save_load(self):