Move NoLayerEmbedTokens (#7945)

* Move NoLayerEmbedTokens

* TFWrappedEmbeddings

* Add comment
Sam Shleifer 2020-10-22 16:13:49 -04:00 committed by GitHub
parent 5ac07513e0
commit 0397619ac6
3 changed files with 43 additions and 67 deletions

modeling_tf_bart.py

@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
TFSharedEmbeddings,
+ TFWrappedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
@@ -132,36 +133,6 @@ LARGE_NEGATIVE = -1e8
logger = logging.get_logger(__name__)
- class _NoLayerEmbedTokens:
- """
- this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
- class to avoid problem with weight restoring. Also it makes sure that the layer is
- called from the correct scope to avoid problem with saving/storing the correct weights
- """
- def __init__(self, layer, abs_scope_name=None):
- self._layer = layer
- self._abs_scope_name = abs_scope_name
- def call(self, inputs, mode="embedding"):
- if self._abs_scope_name is None:
- return self._layer.call(inputs, mode)
- # if an abs scope name is given to the embedding variable, call variable from absolute scope
- with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
- with tf.name_scope(abs_scope_name.original_name_scope):
- return self._layer.call(inputs, mode)
- def __call__(self, inputs, mode="embedding"):
- if self._abs_scope_name is None:
- return self._layer(inputs, mode)
- # if an abs scope name is given to the embedding variable, call variable from absolute scope
- with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
- with tf.name_scope(abs_scope_name.original_name_scope):
- return self._layer(inputs, mode)
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
@@ -826,7 +797,8 @@ class TFBartModel(TFPretrainedBartModel):
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
- embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+ # Wraps the layer to avoid problems with weight restoring and to ensure we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size
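
For readers unfamiliar with the empty "with ...: pass" block above: it exists only to capture an absolute VariableScope object that the wrapper re-enters later, not to create any variables. A minimal standalone sketch of what it captures follows (assuming TF 2.x with the tf.compat.v1 API available; the printed values are indicative, not taken from this commit).

import tensorflow as tf

# The empty block only captures the absolute scope object; no variables are created here.
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
    pass

print(shared_abs_scope_name.name)                 # e.g. "model.shared"
print(shared_abs_scope_name.original_name_scope)  # e.g. "model.shared/"

# TFWrappedEmbeddings later re-enters this exact scope (with auxiliary_name_scope=False),
# so the shared embedding weights keep the same fully qualified name regardless of which
# sub-module (encoder or decoder) calls them.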

modeling_tf_t5.py

@@ -24,6 +24,8 @@ from typing import Tuple
import tensorflow as tf
+ from transformers.modeling_tf_utils import TFWrappedEmbeddings
from .configuration_t5 import T5Config
from .file_utils import (
DUMMY_INPUTS,
@@ -505,36 +507,6 @@ class TFT5Block(tf.keras.layers.Layer):
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
- class _NoLayerEmbedTokens:
- """
- this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer'
- class to avoid problem with weight restoring. Also it makes sure that the layer is
- called from the correct scope to avoid problem with saving/storing the correct weights
- """
- def __init__(self, layer, abs_scope_name=None):
- self._layer = layer
- self._abs_scope_name = abs_scope_name
- def call(self, inputs, mode="embedding"):
- if self._abs_scope_name is None:
- return self._layer.call(inputs, mode)
- # if an abs scope name is given to the embedding variable, call variable from absolute scope
- with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
- with tf.name_scope(abs_scope_name.original_name_scope):
- return self._layer.call(inputs, mode)
- def __call__(self, inputs, mode="embedding"):
- if self._abs_scope_name is None:
- return self._layer(inputs, mode)
- # if an abs scope name is given to the embedding variable, call variable from absolute scope
- with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
- with tf.name_scope(abs_scope_name.original_name_scope):
- return self._layer(inputs, mode)
####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
@@ -980,8 +952,8 @@ class TFT5Model(TFT5PreTrainedModel):
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
- embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+ # Wraps the layer to avoid problems with weight restoring and to ensure we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
@@ -1003,7 +975,8 @@
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
- embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+ # Wraps the layer to avoid problems with weight restoring and to ensure we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
@@ -1177,8 +1150,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
- embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+ # Wraps the layer to avoid problems with weight restoring and to ensure we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
@@ -1199,7 +1172,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
- embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+ # Wraps the layer to avoid problems with weight restoring and to ensure we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)

modeling_tf_utils.py

@@ -1065,3 +1065,33 @@ def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor
# else variable is bool
return bool_variable
+ class TFWrappedEmbeddings:
+ """
+ This class wraps the TFSharedEmbeddings layer in a plain Python (non-Keras-layer) class
+ to avoid problems with weight restoring. It also makes sure the layer is called from the
+ correct scope so that the weights are saved and restored under the correct names.
+ """
+ def __init__(self, layer, abs_scope_name=None):
+ self._layer = layer
+ self._abs_scope_name = abs_scope_name
+ def call(self, inputs, mode="embedding"):
+ if self._abs_scope_name is None:
+ return self._layer.call(inputs, mode)
+ # if an abs scope name is given to the embedding variable, call variable from absolute scope
+ with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+ with tf.name_scope(abs_scope_name.original_name_scope):
+ return self._layer.call(inputs, mode)
+ def __call__(self, inputs, mode="embedding"):
+ if self._abs_scope_name is None:
+ return self._layer(inputs, mode)
+ # if an abs scope name is given to the embedding variable, call variable from absolute scope
+ with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
+ with tf.name_scope(abs_scope_name.original_name_scope):
+ return self._layer(inputs, mode)
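
To make the intended usage concrete, here is a minimal, self-contained sketch that mirrors the TFT5Model / TFBartModel pattern from the hunks above. The model name TinyEncoderDecoder and the toy sizes are hypothetical, and the sketch assumes a transformers version from around this commit, where TFSharedEmbeddings and TFWrappedEmbeddings both live in transformers.modeling_tf_utils.

import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings, TFWrappedEmbeddings


class TinyEncoderDecoder(tf.keras.Model):
    # Toy model illustrating the wrapping pattern; not part of the library.
    def __init__(self, vocab_size=100, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.shared = TFSharedEmbeddings(vocab_size, hidden_size, name="shared")

        # Retrieve the absolute scope under which the shared weights live,
        # exactly as TFT5Model does above.
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass

        # Plain-Python wrapper: not tracked as an extra Keras layer, but it calls the
        # shared embeddings from the correct scope so weight names stay stable.
        self.embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)

    def call(self, input_ids):
        hidden_states = self.embed_tokens(input_ids, mode="embedding")
        # Reuse the same weights as the output projection ("linear" mode).
        return self.embed_tokens(hidden_states, mode="linear")


model = TinyEncoderDecoder()
logits = model(tf.constant([[1, 2, 3]]))
print(logits.shape)  # (1, 3, 100)

Because embed_tokens is not itself a Keras layer, the shared weights are created and tracked only once (under self.shared), which is what avoids the weight-restoring problems described in the class docstring.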