Mirror of https://github.com/huggingface/transformers.git
Move NoLayerEmbedTokens (#7945)
* Move NoLayerEmbedTokens
* TFWrappedEmbeddings
* Add comment
parent 5ac07513e0
commit 0397619ac6
@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
     DUMMY_INPUTS,
     TFPreTrainedModel,
     TFSharedEmbeddings,
+    TFWrappedEmbeddings,
     cast_bool_to_primitive,
     keras_serializable,
     shape_list,
@@ -132,36 +133,6 @@ LARGE_NEGATIVE = -1e8
 logger = logging.get_logger(__name__)


-class _NoLayerEmbedTokens:
-    """
-    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
-    class to avoid problem with weight restoring. Also it makes sure that the layer is
-    called from the correct scope to avoid problem with saving/storing the correct weights
-    """
-
-    def __init__(self, layer, abs_scope_name=None):
-        self._layer = layer
-        self._abs_scope_name = abs_scope_name
-
-    def call(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer.call(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer.call(inputs, mode)
-
-    def __call__(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer(inputs, mode)
-
-
 def create_position_ids_from_input_ids(input_ids, padding_idx):
     """Replace non-padding symbols with their position numbers. Position numbers begin at
     padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
@@ -826,7 +797,8 @@ class TFBartModel(TFPretrainedBartModel):
         with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
             pass

-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         embed_tokens.vocab_size = self.shared.vocab_size
         embed_tokens.hidden_size = self.shared.hidden_size

@@ -24,6 +24,8 @@ from typing import Tuple

 import tensorflow as tf

+from transformers.modeling_tf_utils import TFWrappedEmbeddings
+
 from .configuration_t5 import T5Config
 from .file_utils import (
     DUMMY_INPUTS,
@@ -505,36 +507,6 @@ class TFT5Block(tf.keras.layers.Layer):
         return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


-class _NoLayerEmbedTokens:
-    """
-    this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
-    class to avoid problem with weight restoring. Also it makes sure that the layer is
-    called from the correct scope to avoid problem with saving/storing the correct weights
-    """
-
-    def __init__(self, layer, abs_scope_name=None):
-        self._layer = layer
-        self._abs_scope_name = abs_scope_name
-
-    def call(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer.call(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer.call(inputs, mode)
-
-    def __call__(self, inputs, mode="embedding"):
-        if self._abs_scope_name is None:
-            return self._layer(inputs, mode)
-
-        # if an abs scope name is given to the embedding variable, call variable from absolute scope
-        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
-            with tf.name_scope(abs_scope_name.original_name_scope):
-                return self._layer(inputs, mode)
-
-
 ####################################################
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
@@ -980,8 +952,8 @@ class TFT5Model(TFT5PreTrainedModel):
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
-
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)

         encoder_config = copy.deepcopy(config)
         encoder_config.use_cache = False
@@ -1003,7 +975,8 @@ class TFT5Model(TFT5PreTrainedModel):
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

@@ -1177,8 +1150,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
-
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)

         encoder_config = copy.deepcopy(config)
         encoder_config.use_cache = False
@@ -1199,7 +1172,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
             pass
-        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

@@ -1065,3 +1065,33 @@ def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor

     # else variable is bool
     return bool_variable
+
+
+class TFWrappedEmbeddings:
+    """
+    This class wraps the TFSharedEmbeddings layer into a plain, non-Keras-layer Python class to avoid
+    problems with weight restoring. It also makes sure that the layer is called from the correct scope
+    so that the correct weights are saved and restored.
+    """
+
+    def __init__(self, layer, abs_scope_name=None):
+        self._layer = layer
+        self._abs_scope_name = abs_scope_name
+
+    def call(self, inputs, mode="embedding"):
+        if self._abs_scope_name is None:
+            return self._layer.call(inputs, mode)
+
+        # if an abs scope name is given to the embedding variable, call variable from absolute scope
+        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+            with tf.name_scope(abs_scope_name.original_name_scope):
+                return self._layer.call(inputs, mode)
+
+    def __call__(self, inputs, mode="embedding"):
+        if self._abs_scope_name is None:
+            return self._layer(inputs, mode)
+
+        # if an abs scope name is given to the embedding variable, call variable from absolute scope
+        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+            with tf.name_scope(abs_scope_name.original_name_scope):
+                return self._layer(inputs, mode)
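For reference, below is a minimal, self-contained sketch of the pattern this diff settles on: a model owns the real TFSharedEmbeddings layer and hands out a TFWrappedEmbeddings view of it, which re-enters a fixed absolute scope on every call. The ToyModel class, its sizes, and anything else not named in the diff above are illustrative assumptions, not part of the commit.

import tensorflow as tf

from transformers.modeling_tf_utils import TFSharedEmbeddings, TFWrappedEmbeddings


class ToyModel(tf.keras.Model):
    """Illustrative model (not from the commit) showing how the shared embeddings get wrapped."""

    def __init__(self, vocab_size=32, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        # The shared input/output embedding table: a real Keras layer that owns the weights.
        self.shared = TFSharedEmbeddings(vocab_size, hidden_size, name="shared")

        # Grab an absolute variable-scope handle once; the wrapper re-enters it on every call
        # so the embedding variables are always resolved under the same names.
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass

        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
        # The wrapper is deliberately *not* a Keras layer, so passing it around does not register
        # the shared embedding layer a second time (the weight-restoring problem noted in the docstring).
        self.embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)

    def call(self, input_ids):
        # mode="embedding" maps token ids to vectors; mode="linear" would project hidden states
        # back onto the vocabulary with the same (shared) weight matrix.
        return self.embed_tokens(input_ids, mode="embedding")


model = ToyModel()
hidden_states = model(tf.constant([[1, 2, 3]]))  # shape: (1, 3, 16)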