Proper build() methods for TF (#27794)

* Add a convenience method for building in your own name scope

* Second attempt at auto layer building

* Revert "Second attempt at auto layer building"

This reverts commit e03a3aaecf9ec41a805582b83cbdfe3290a631be.

* Attempt #3

* Revert "Attempt #3"

This reverts commit b9df7a0857560d29b5abbed6127d9e9eca77cf47.

* Add missing attributes that we're going to need later

* Add some attributes we're going to need later

* A fourth attempt! Feel the power flow through you!

* Revert "A fourth attempt! Feel the power flow through you!"

This reverts commit 6bf4aaf3875d6f28485f50187617a4c616c8aff7.

* Add more values we'll need later

* TF refactor that we'll need later

* Revert "TF refactor that we'll need later"

This reverts commit ca07202fb5b7b7436b893baa8d688b4f348ea7b9.

* Revert "Revert "TF refactor that we'll need later""

This reverts commit 1beb0f39f293ed9c27594575e1c849aadeb15c13.

* make fixup

* Attempt five!

* Revert "Attempt five!"

This reverts commit 3302207958dfd0374b0447a51c06eea51a506044.

* Attempt six - this time don't add empty methods

* Revert "Attempt six - this time don't add empty methods"

This reverts commit 67d60129be75416b6beb8f47c7d38d77b18d79bb.

* Attempt seven - better base model class detection!

* Revert "Attempt seven - better base model class detection!"

This reverts commit 5f14845e92ea0e87c598da933bfbfee10f553bc9.

* Another attribute we'll need later

* Try again with the missing attribute!

* Revert "Try again with the missing attribute!"

This reverts commit 760c6f30c5dffb3e04b0e73c34a77d1882a0fef7.

* This is the attempt that will pierce the heavens!

* Revert "This is the attempt that will pierce the heavens!"

This reverts commit c868bb657de057aca7a5260350a3f831fc4dfee6.

* Attempt seven - snag list is steadily decreasing

* Revert "Attempt seven - snag list is steadily decreasing"

This reverts commit 46fbd975deda64429bfb3e5fac4fc0370c00d316.

* Attempt eight - will an empty snag list do it?

* Revert "Attempt eight - will an empty snag list do it?"

This reverts commit 7c8a3c2b083253649569e9877e02054ae5cec67b.

* Fixes to Hubert issues that cause problems later

* Trying again with Conv1D/SeparableConv fixes

* Revert "Trying again with Conv1D/SeparableConv fixes"

This reverts commit 55092bca952bc0f750aa1ffe246a640bf1e2036e.

* Apply the build shape fixes to Wav2Vec2 as well

* One more attempt!

* Revert "One more attempt!"

This reverts commit 5ac3e4cb01b9458cc93312873725f9444ae7261c.

* Another attempt!

* Revert "Another attempt!"

This reverts commit ea16d890e019d7de8792a3b8e72f3b1c02adae50.

* Let's see how many failures we get without the internal build method

* Fix OpenAI

* Fix MobileBERT

* (Mostly) fix GroupVIT

* Fix BLIP

* One more BLIP fix

* One more BLIP fix!

* Fix Regnet

* Finally fully fix GroupViT

* Fix Data2Vec and add the new AdaptivePool

* Fix Segformer

* Fix Albert

* Fix Deberta/DebertaV2

* Fix XLM

* Actually fix XLM

* Fix Flaubert

* Fix lxmert

* Fix Resnet

* Fix ConvBERT

* Fix ESM

* Fix Convnext / ConvnextV2

* Fix SAM

* Fix Efficientformer

* Fix LayoutLMv3

* Fix speech_to_text

* Fix mpnet and mobilevit

* Fix Swin

* Fix CTRL

* Fix CVT

* Fix DPR

* Fix Wav2Vec2

* Fix T5

* Fix Hubert

* Fix GPT2

* Fix Whisper

* Fix DeiT

* Fix the encoder-decoder / dual-encoder classes

* make fix-copies

* build in name scope

* Fix summarization test

* Fix tied weight names for BART + Blenderbot

* Fix tied weight name building

* Fix to TFESM weight building

* Update TF SAM

* Expand all the shapes out into Big Boy Shapes
This commit is contained in:
Matt 2023-12-14 15:17:30 +00:00 committed by GitHub
parent 52c37882fb
commit 050e0b44f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 11040 additions and 504 deletions

View File

@ -35,7 +35,6 @@ import tensorflow as tf
from huggingface_hub import Repository, list_repo_files
from keras import backend as K
from packaging.version import parse
from tensorflow.python.util.keras_deps import get_call_context_function
from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
@ -1122,6 +1121,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
)
return dummies
def build_in_name_scope(self):
with tf.name_scope(self.name):
self.build(input_shape=None)
@property
def framework(self) -> str:
"""
@ -1130,15 +1133,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return "tf"
def build(self, input_shape=None):
call_context = get_call_context_function()
if self.built or call_context().in_call:
self.built = True
else:
self.built = True
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
self._set_save_spec(self.input_signature)
self(self.dummy_inputs, training=False)
pass # This is just here to make sure we don't call the superclass build()
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -1869,7 +1864,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
main_layer.set_input_embeddings(value)
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()
main_layer.set_input_embeddings(value)
def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@ -1886,7 +1881,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return lm_head.get_output_embeddings()
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()
return lm_head().get_output_embeddings()
@ -1906,7 +1901,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
lm_head.set_output_embeddings(value)
except AttributeError:
logger.info("Building the model")
self.build()
self.build_in_name_scope()
lm_head.set_output_embeddings(value)
def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@ -1944,7 +1939,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
try:
return lm_head.get_bias()
except AttributeError:
self.build()
self.build_in_name_scope()
return lm_head.get_bias()
return None
@ -1962,7 +1957,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
try:
lm_head.set_bias(value)
except AttributeError:
self.build()
self.build_in_name_scope()
lm_head.set_bias(value)
def get_lm_head(self) -> tf.keras.layers.Layer:
@ -2049,7 +2044,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# The reason why the attributes don't exist might be
# because the model is not built, so retry getting
# the argument after building the model
model.build()
model.build_in_name_scope()
embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
@ -2914,9 +2909,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
with tf.compat.v1.variable_scope(load_weight_prefix):
model.build() # build the network with dummy inputs
model.build_in_name_scope() # build the network with dummy inputs
else:
model.build() # build the network with dummy inputs
model.build_in_name_scope() # build the network with dummy inputs
if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
@ -3215,6 +3210,9 @@ class TFConv1D(tf.keras.layers.Layer):
self.initializer_range = initializer_range
def build(self, input_shape):
if self.built:
return
self.built = True
self.weight = self.add_weight(
"weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
)
@ -3398,6 +3396,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
if self.has_last_dropout:
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
self.hidden_size = config.hidden_size
def call(self, inputs, cls_index=None, training=False):
if not isinstance(inputs, (dict, tuple, list)):
@ -3450,6 +3449,14 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
def build(self, input_shape):
if self.built:
return
self.built = True
if getattr(self, "summary", None) is not None:
with tf.name_scope("summary"):
self.summary.build(self.hidden_size)
def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
"""

View File

@ -146,7 +146,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -168,7 +168,12 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(
@ -246,6 +251,7 @@ class TFAlbertAttention(tf.keras.layers.Layer):
# Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
self.attention_dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.output_dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -307,6 +313,26 @@ class TFAlbertAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFAlbertLayer(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@ -329,6 +355,7 @@ class TFAlbertLayer(tf.keras.layers.Layer):
epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(
self,
@ -356,6 +383,23 @@ class TFAlbertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "ffn", None) is not None:
with tf.name_scope(self.ffn.name):
self.ffn.build([None, None, self.config.hidden_size])
if getattr(self, "ffn_output", None) is not None:
with tf.name_scope(self.ffn_output.name):
self.ffn_output.build([None, None, self.config.intermediate_size])
if getattr(self, "full_layer_layer_norm", None) is not None:
with tf.name_scope(self.full_layer_layer_norm.name):
self.full_layer_layer_norm.build([None, None, self.config.hidden_size])
class TFAlbertLayerGroup(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@ -399,6 +443,15 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert_layers", None) is not None:
for layer in self.albert_layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFAlbertTransformer(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@ -416,6 +469,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
self.albert_layer_groups = [
TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
]
self.config = config
def call(
self,
@ -457,6 +511,18 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedding_hidden_mapping_in", None) is not None:
with tf.name_scope(self.embedding_hidden_mapping_in.name):
self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
if getattr(self, "albert_layer_groups", None) is not None:
for layer in self.albert_layer_groups:
with tf.name_scope(layer.name):
layer.build(None)
class TFAlbertPreTrainedModel(TFPreTrainedModel):
"""
@ -488,13 +554,21 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
self.decoder_bias = self.add_weight(
shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.decoder
@ -650,6 +724,20 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build([None, None, self.config.hidden_size])
@dataclass
class TFAlbertForPreTrainingOutput(ModelOutput):
@ -825,6 +913,14 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
@add_start_docstrings(
"""
@ -921,6 +1017,20 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
if getattr(self, "sop_classifier", None) is not None:
with tf.name_scope(self.sop_classifier.name):
self.sop_classifier.build(None)
class TFAlbertSOPHead(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@ -932,6 +1042,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
@ -939,6 +1050,14 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
return logits
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1035,6 +1154,17 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@add_start_docstrings(
"""
@ -1058,6 +1188,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1117,6 +1248,17 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1145,6 +1287,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1200,6 +1343,17 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1221,6 +1375,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1295,6 +1450,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1316,6 +1482,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1394,3 +1561,14 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "albert", None) is not None:
with tf.name_scope(self.albert.name):
self.albert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])

View File

@ -43,7 +43,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -296,6 +295,23 @@ class TFBartAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFBartEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
@ -311,6 +327,7 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -352,6 +369,26 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFBartDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
@ -380,6 +417,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -461,6 +499,32 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFBartClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -470,6 +534,8 @@ class TFBartClassificationHead(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(inner_dim, name="dense")
self.dropout = tf.keras.layers.Dropout(pooler_dropout)
self.out_proj = tf.keras.layers.Dense(num_classes, name="out_proj")
self.input_dim = inner_dim
self.inner_dim = inner_dim
def call(self, inputs):
hidden_states = self.dropout(inputs)
@ -479,6 +545,17 @@ class TFBartClassificationHead(tf.keras.layers.Layer):
hidden_states = self.out_proj(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.input_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.inner_dim])
class TFBartPretrainedModel(TFPreTrainedModel):
config_class = BartConfig
@ -686,6 +763,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
)
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.embed_dim = config.d_model
@unpack_inputs
def call(
@ -745,16 +823,8 @@ class TFBartEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -809,6 +879,21 @@ class TFBartEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.embed_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBartDecoder(tf.keras.layers.Layer):
@ -938,16 +1023,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1032,6 +1109,21 @@ class TFBartDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBartMainLayer(tf.keras.layers.Layer):
@ -1149,6 +1241,22 @@ class TFBartMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
@ -1237,6 +1345,14 @@ class TFBartModel(TFBartPretrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
class BiasLayer(tf.keras.layers.Layer):
"""
@ -1440,6 +1556,17 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)
@add_start_docstrings(
"""
@ -1567,3 +1694,14 @@ class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassific
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "classification_head", None) is not None:
with tf.name_scope(self.classification_head.name):
self.classification_head.build(None)

View File

@ -156,7 +156,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -178,7 +178,12 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def call(
self,
@ -248,6 +253,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -337,6 +343,20 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -347,6 +367,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -355,6 +376,17 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertAttention(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -395,6 +427,17 @@ class TFBertAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -408,6 +451,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -415,6 +459,14 @@ class TFBertIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -425,6 +477,7 @@ class TFBertOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -433,6 +486,17 @@ class TFBertOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertLayer(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -519,6 +583,23 @@ class TFBertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -588,6 +669,15 @@ class TFBertEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -599,6 +689,7 @@ class TFBertPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -608,6 +699,14 @@ class TFBertPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -625,6 +724,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -633,6 +733,17 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
@ -647,10 +758,15 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -688,6 +804,14 @@ class TFBertMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config: BertConfig, **kwargs):
@ -698,12 +822,21 @@ class TFBertNSPHead(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="seq_relationship",
)
self.config = config
def call(self, pooled_output: tf.Tensor) -> tf.Tensor:
seq_relationship_score = self.seq_relationship(inputs=pooled_output)
return seq_relationship_score
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "seq_relationship", None) is not None:
with tf.name_scope(self.seq_relationship.name):
self.seq_relationship.build([None, None, self.config.hidden_size])
@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
@ -891,6 +1024,20 @@ class TFBertMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFBertPreTrainedModel(TFPreTrainedModel):
"""
@ -1103,6 +1250,14 @@ class TFBertModel(TFBertPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
@add_start_docstrings(
"""
@ -1215,6 +1370,20 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "nsp", None) is not None:
with tf.name_scope(self.nsp.name):
self.nsp.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1301,6 +1470,17 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
@ -1426,6 +1606,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top.""",
@ -1508,6 +1699,17 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "nsp", None) is not None:
with tf.name_scope(self.nsp.name):
self.nsp.build(None)
@add_start_docstrings(
"""
@ -1536,6 +1738,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1594,6 +1797,17 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1615,6 +1829,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1693,6 +1908,17 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1727,6 +1953,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1783,6 +2010,17 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1812,6 +2050,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
kernel_initializer=get_initializer(config.initializer_range),
name="qa_outputs",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1884,3 +2123,14 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -291,6 +290,23 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Blenderbot
class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
@ -307,6 +323,7 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -348,6 +365,26 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Blenderbot
class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
@ -377,6 +414,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -458,6 +496,32 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFBlenderbotPreTrainedModel(TFPreTrainedModel):
config_class = BlenderbotConfig
@ -711,16 +775,8 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -776,6 +832,21 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBlenderbotDecoder(tf.keras.layers.Layer):
@ -916,12 +987,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1006,6 +1073,21 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBlenderbotMainLayer(tf.keras.layers.Layer):
@ -1114,6 +1196,22 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.",
@ -1217,6 +1315,14 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -1436,3 +1542,14 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -291,6 +290,23 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartEncoderLayer with Bart->BlenderbotSmall
class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
@ -307,6 +323,7 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -348,6 +365,26 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartDecoderLayer with Bart->BlenderbotSmall
class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
@ -377,6 +414,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -458,6 +496,32 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel):
config_class = BlenderbotSmallConfig
@ -646,6 +710,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
)
self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.embed_dim = config.d_model
def get_embed_tokens(self):
return self.embed_tokens
@ -717,16 +782,8 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -781,6 +838,21 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.embed_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
@ -917,16 +989,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if input_shape[-1] > 1:
@ -1014,6 +1078,21 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
@ -1122,6 +1201,22 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.",
@ -1209,6 +1304,14 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -1413,3 +1516,14 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -254,7 +254,7 @@ class TFBlipVisionEmbeddings(tf.keras.layers.Layer):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
def build(self, input_shape):
def build(self, input_shape=None):
self.class_embedding = self.add_weight(
shape=(1, 1, self.embed_dim),
initializer=get_initializer(self.config.initializer_range),
@ -268,7 +268,13 @@ class TFBlipVisionEmbeddings(tf.keras.layers.Layer):
trainable=True,
name="position_embedding",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embedding", None) is not None:
with tf.name_scope(self.patch_embedding.name):
self.patch_embedding.build([None, None, None, 3])
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
@ -412,6 +418,20 @@ class TFBlipAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "qkv", None) is not None:
with tf.name_scope(self.qkv.name):
self.qkv.build([None, None, self.embed_dim])
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, self.embed_dim])
class TFBlipMLP(tf.keras.layers.Layer):
def __init__(self, config: BlipConfig, **kwargs):
@ -428,6 +448,7 @@ class TFBlipMLP(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2"
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.fc1(inputs=hidden_states)
@ -435,6 +456,17 @@ class TFBlipMLP(tf.keras.layers.Layer):
hidden_states = self.fc2(inputs=hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.config.hidden_size])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.intermediate_size])
class TFBlipEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BlipConfig, **kwargs):
@ -485,6 +517,23 @@ class TFBlipEncoderLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, self.embed_dim])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, self.embed_dim])
class TFBlipPreTrainedModel(TFPreTrainedModel):
"""
@ -645,6 +694,15 @@ class TFBlipEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFBlipVisionModel(TFBlipPreTrainedModel):
main_input_name = "pixel_values"
@ -657,6 +715,7 @@ class TFBlipVisionModel(TFBlipPreTrainedModel):
self.embeddings = TFBlipVisionEmbeddings(config, name="embeddings")
self.encoder = TFBlipEncoder(config, name="encoder")
self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
self.embed_dim = config.hidden_size
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
@ -724,6 +783,20 @@ class TFBlipVisionModel(TFBlipPreTrainedModel):
def get_input_embeddings(self):
return self.embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "post_layernorm", None) is not None:
with tf.name_scope(self.post_layernorm.name):
self.post_layernorm.build([None, None, self.embed_dim])
class TFBlipMainLayer(tf.keras.layers.Layer):
config_class = BlipConfig
@ -775,7 +848,22 @@ class TFBlipMainLayer(tf.keras.layers.Layer):
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
trainable=True,
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "text_model", None) is not None:
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "visual_projection", None) is not None:
with tf.name_scope(self.visual_projection.name):
self.visual_projection.build([None, None, self.vision_embed_dim])
if getattr(self, "text_projection", None) is not None:
with tf.name_scope(self.text_projection.name):
self.text_projection.build([None, None, self.text_embed_dim])
@unpack_inputs
def call(
@ -995,6 +1083,14 @@ class TFBlipModel(TFBlipPreTrainedModel):
return image_features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "blip", None) is not None:
with tf.name_scope(self.blip.name):
self.blip.build(None)
@add_start_docstrings(
"""
@ -1168,6 +1264,17 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "text_decoder", None) is not None:
with tf.name_scope(self.text_decoder.name):
self.text_decoder.build(None)
@add_start_docstrings(
"""
@ -1409,6 +1516,20 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "text_encoder", None) is not None:
with tf.name_scope(self.text_encoder.name):
self.text_encoder.build(None)
if getattr(self, "text_decoder", None) is not None:
with tf.name_scope(self.text_decoder.name):
self.text_decoder.build(None)
@add_start_docstrings(
"""
@ -1457,6 +1578,7 @@ class TFBlipForImageTextRetrieval(TFBlipPreTrainedModel):
if not hasattr(config, "decoder_start_token_id")
else config.decoder_start_token_id
)
self.config = config
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.vision_model.embeddings.patch_embedding
@ -1558,3 +1680,23 @@ class TFBlipForImageTextRetrieval(TFBlipPreTrainedModel):
attentions=vision_outputs.attentions,
question_embeds=question_embeds,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "text_encoder", None) is not None:
with tf.name_scope(self.text_encoder.name):
self.text_encoder.build(None)
if getattr(self, "vision_proj", None) is not None:
with tf.name_scope(self.vision_proj.name):
self.vision_proj.build([None, None, self.config.vision_config.hidden_size])
if getattr(self, "text_proj", None) is not None:
with tf.name_scope(self.text_proj.name):
self.text_proj.build([None, None, self.config.text_config.hidden_size])
if getattr(self, "itm_head", None) is not None:
with tf.name_scope(self.itm_head.name):
self.itm_head.build([None, None, self.config.text_config.hidden_size])

View File

@ -127,6 +127,23 @@ class TFBlipTextEmbeddings(tf.keras.layers.Layer):
embeddings = self.dropout(embeddings, training=training)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "word_embeddings", None) is not None:
with tf.name_scope(self.word_embeddings.name):
self.word_embeddings.build(None)
if getattr(self, "position_embeddings", None) is not None:
with tf.name_scope(self.position_embeddings.name):
self.position_embeddings.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97
class TFBlipTextSelfAttention(tf.keras.layers.Layer):
@ -160,6 +177,7 @@ class TFBlipTextSelfAttention(tf.keras.layers.Layer):
self.distance_embedding = tf.keras.layers.Embedding(
2 * config.max_position_embeddings - 1, self.attention_head_size
)
self.is_cross_attention = is_cross_attention
def transpose_for_scores(self, x):
new_x_shape = tf.concat(
@ -250,6 +268,28 @@ class TFBlipTextSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if self.is_cross_attention:
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.encoder_hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.encoder_hidden_size])
else:
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
class TFBlipTextSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
@ -260,6 +300,7 @@ class TFBlipTextSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -268,6 +309,17 @@ class TFBlipTextSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242
class TFBlipTextAttention(tf.keras.layers.Layer):
@ -302,6 +354,17 @@ class TFBlipTextAttention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "self_output", None) is not None:
with tf.name_scope(self.self_output.name):
self.self_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->BlipText
class TFBlipTextIntermediate(tf.keras.layers.Layer):
@ -316,6 +379,7 @@ class TFBlipTextIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -323,6 +387,14 @@ class TFBlipTextIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFBlipTextOutput(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
@ -333,6 +405,7 @@ class TFBlipTextOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -341,6 +414,17 @@ class TFBlipTextOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBlipTextLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -400,6 +484,23 @@ class TFBlipTextLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "self_output", None) is not None:
with tf.name_scope(self.self_output.name):
self.self_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386
@keras_serializable
@ -481,6 +582,15 @@ class TFBlipTextEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->BlipText
class TFBlipTextPooler(tf.keras.layers.Layer):
@ -493,6 +603,7 @@ class TFBlipTextPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -502,6 +613,14 @@ class TFBlipTextPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->BlipText
class TFBlipTextPredictionHeadTransform(tf.keras.layers.Layer):
@ -520,6 +639,7 @@ class TFBlipTextPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -528,6 +648,17 @@ class TFBlipTextPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBlipTextLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -546,7 +677,16 @@ class TFBlipTextLMPredictionHead(tf.keras.layers.Layer):
def build(self, input_shape=None):
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build([None, None, self.config.hidden_size])
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
@ -563,6 +703,14 @@ class TFBlipTextOnlyMLMHead(tf.keras.layers.Layer):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548
class TFBlipTextPreTrainedModel(TFPreTrainedModel):
@ -802,6 +950,20 @@ class TFBlipTextModel(TFBlipTextPreTrainedModel):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
@ -942,3 +1104,14 @@ class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert", None) is not None:
with tf.name_scope(self.bert.name):
self.bert.build(None)
if getattr(self, "cls", None) is not None:
with tf.name_scope(self.cls.name):
self.cls.build(None)

View File

@ -184,7 +184,7 @@ class TFCamembertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -206,7 +206,12 @@ class TFCamembertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
@ -279,6 +284,7 @@ class TFCamembertPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -288,6 +294,14 @@ class TFCamembertPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Camembert
class TFCamembertSelfAttention(tf.keras.layers.Layer):
@ -317,6 +331,7 @@ class TFCamembertSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -406,6 +421,20 @@ class TFCamembertSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Camembert
class TFCamembertSelfOutput(tf.keras.layers.Layer):
@ -417,6 +446,7 @@ class TFCamembertSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -425,6 +455,17 @@ class TFCamembertSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Camembert
class TFCamembertAttention(tf.keras.layers.Layer):
@ -466,6 +507,17 @@ class TFCamembertAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Camembert
class TFCamembertIntermediate(tf.keras.layers.Layer):
@ -480,6 +532,7 @@ class TFCamembertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -487,6 +540,14 @@ class TFCamembertIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Camembert
class TFCamembertOutput(tf.keras.layers.Layer):
@ -498,6 +559,7 @@ class TFCamembertOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -506,6 +568,17 @@ class TFCamembertOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Camembert
class TFCamembertLayer(tf.keras.layers.Layer):
@ -593,6 +666,23 @@ class TFCamembertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Camembert
class TFCamembertEncoder(tf.keras.layers.Layer):
@ -663,6 +753,15 @@ class TFCamembertEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->Camembert
@ -861,6 +960,20 @@ class TFCamembertMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
class TFCamembertPreTrainedModel(TFPreTrainedModel):
"""
@ -945,6 +1058,14 @@ class TFCamembertModel(TFCamembertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Camembert
class TFCamembertLMHead(tf.keras.layers.Layer):
@ -965,10 +1086,18 @@ class TFCamembertLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -1080,6 +1209,17 @@ class TFCamembertForMaskedLM(TFCamembertPreTrainedModel, TFMaskedLanguageModelin
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead
class TFCamembertClassificationHead(tf.keras.layers.Layer):
@ -1100,6 +1240,7 @@ class TFCamembertClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1109,6 +1250,17 @@ class TFCamembertClassificationHead(tf.keras.layers.Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1186,6 +1338,17 @@ class TFCamembertForSequenceClassification(TFCamembertPreTrainedModel, TFSequenc
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1212,6 +1375,7 @@ class TFCamembertForTokenClassification(TFCamembertPreTrainedModel, TFTokenClass
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1270,6 +1434,17 @@ class TFCamembertForTokenClassification(TFCamembertPreTrainedModel, TFTokenClass
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1292,6 +1467,7 @@ class TFCamembertForMultipleChoice(TFCamembertPreTrainedModel, TFMultipleChoiceL
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1363,6 +1539,17 @@ class TFCamembertForMultipleChoice(TFCamembertPreTrainedModel, TFMultipleChoiceL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1384,6 +1571,7 @@ class TFCamembertForQuestionAnswering(TFCamembertPreTrainedModel, TFQuestionAnsw
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1456,6 +1644,17 @@ class TFCamembertForQuestionAnswering(TFCamembertPreTrainedModel, TFQuestionAnsw
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING
@ -1581,3 +1780,14 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)

View File

@ -169,7 +169,12 @@ class TFCLIPVisionEmbeddings(tf.keras.layers.Layer):
name="embeddings",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embedding", None) is not None:
with tf.name_scope(self.patch_embedding.name):
self.patch_embedding.build([None, None, None, self.config.num_channels])
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
"""`pixel_values` is expected to be of NCHW format."""
@ -352,6 +357,23 @@ class TFCLIPAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFCLIPMLP(tf.keras.layers.Layer):
def __init__(self, config: CLIPConfig, **kwargs):
@ -369,6 +391,7 @@ class TFCLIPMLP(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2"
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.fc1(inputs=hidden_states)
@ -376,6 +399,17 @@ class TFCLIPMLP(tf.keras.layers.Layer):
hidden_states = self.fc2(inputs=hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.config.hidden_size])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.intermediate_size])
class TFCLIPEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: CLIPConfig, **kwargs):
@ -428,6 +462,23 @@ class TFCLIPEncoderLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, self.embed_dim])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, self.embed_dim])
class TFCLIPEncoder(tf.keras.layers.Layer):
"""
@ -483,6 +534,15 @@ class TFCLIPEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFCLIPTextTransformer(tf.keras.layers.Layer):
def __init__(self, config: CLIPTextConfig, **kwargs):
@ -496,6 +556,7 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
self.embed_dim = config.hidden_size
def call(
self,
@ -586,6 +647,20 @@ class TFCLIPTextTransformer(tf.keras.layers.Layer):
return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length))
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
@keras_serializable
class TFCLIPTextMainLayer(tf.keras.layers.Layer):
@ -634,6 +709,14 @@ class TFCLIPTextMainLayer(tf.keras.layers.Layer):
return text_model_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "text_model", None) is not None:
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
class TFCLIPVisionTransformer(tf.keras.layers.Layer):
def __init__(self, config: CLIPVisionConfig, **kwargs):
@ -643,6 +726,7 @@ class TFCLIPVisionTransformer(tf.keras.layers.Layer):
self.pre_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
self.encoder = TFCLIPEncoder(config, name="encoder")
self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
self.embed_dim = config.hidden_size
def call(
self,
@ -679,6 +763,23 @@ class TFCLIPVisionTransformer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "pre_layernorm", None) is not None:
with tf.name_scope(self.pre_layernorm.name):
self.pre_layernorm.build([None, None, self.embed_dim])
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "post_layernorm", None) is not None:
with tf.name_scope(self.post_layernorm.name):
self.post_layernorm.build([None, self.embed_dim])
@keras_serializable
class TFCLIPVisionMainLayer(tf.keras.layers.Layer):
@ -714,6 +815,14 @@ class TFCLIPVisionMainLayer(tf.keras.layers.Layer):
return vision_model_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
@keras_serializable
class TFCLIPMainLayer(tf.keras.layers.Layer):
@ -757,6 +866,8 @@ class TFCLIPMainLayer(tf.keras.layers.Layer):
use_bias=False,
name="text_projection",
)
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
def build(self, input_shape: tf.TensorShape = None):
self.logit_scale = self.add_weight(
@ -766,7 +877,21 @@ class TFCLIPMainLayer(tf.keras.layers.Layer):
name="logit_scale",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "text_model", None) is not None:
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "visual_projection", None) is not None:
with tf.name_scope(self.visual_projection.name):
self.visual_projection.build([None, None, self.vision_embed_dim])
if getattr(self, "text_projection", None) is not None:
with tf.name_scope(self.text_projection.name):
self.text_projection.build([None, None, self.text_embed_dim])
@unpack_inputs
def get_text_features(
@ -1108,6 +1233,14 @@ class TFCLIPTextModel(TFCLIPPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "clip", None) is not None:
with tf.name_scope(self.clip.name):
self.clip.build(None)
class TFCLIPVisionModel(TFCLIPPreTrainedModel):
config_class = CLIPVisionConfig
@ -1162,6 +1295,14 @@ class TFCLIPVisionModel(TFCLIPPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "clip", None) is not None:
with tf.name_scope(self.clip.name):
self.clip.build(None)
@add_start_docstrings(CLIP_START_DOCSTRING)
class TFCLIPModel(TFCLIPPreTrainedModel):
@ -1313,3 +1454,11 @@ class TFCLIPModel(TFCLIPPreTrainedModel):
# TensorFlow cannot trace through nested dataclasses. Reference:
# https://github.com/huggingface/transformers/pull/16886
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "clip", None) is not None:
with tf.name_scope(self.clip.name):
self.clip.build(None)

View File

@ -81,7 +81,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -103,7 +103,12 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(
@ -208,6 +213,7 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -297,6 +303,29 @@ class TFConvBertSelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "key_conv_attn_layer", None) is not None:
with tf.name_scope(self.key_conv_attn_layer.name):
self.key_conv_attn_layer.build([None, None, self.config.hidden_size])
if getattr(self, "conv_kernel_layer", None) is not None:
with tf.name_scope(self.conv_kernel_layer.name):
self.conv_kernel_layer.build([None, None, self.all_head_size])
if getattr(self, "conv_out_layer", None) is not None:
with tf.name_scope(self.conv_out_layer.name):
self.conv_out_layer.build([None, None, self.config.hidden_size])
class TFConvBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -307,6 +336,7 @@ class TFConvBertSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -315,6 +345,17 @@ class TFConvBertSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFConvBertAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -335,6 +376,17 @@ class TFConvBertAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class GroupedLinearLayer(tf.keras.layers.Layer):
def __init__(self, input_size, output_size, num_groups, kernel_initializer, **kwargs):
@ -389,6 +441,7 @@ class TFConvBertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -396,6 +449,14 @@ class TFConvBertIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFConvBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -415,6 +476,7 @@ class TFConvBertOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -423,6 +485,17 @@ class TFConvBertOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFConvBertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -443,6 +516,20 @@ class TFConvBertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
class TFConvBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -486,6 +573,15 @@ class TFConvBertEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFConvBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -501,6 +597,7 @@ class TFConvBertPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -509,6 +606,17 @@ class TFConvBertPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
@keras_serializable
class TFConvBertMainLayer(tf.keras.layers.Layer):
@ -616,6 +724,20 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "embeddings_project", None) is not None:
with tf.name_scope(self.embeddings_project.name):
self.embeddings_project.build([None, None, self.config.embedding_size])
class TFConvBertPreTrainedModel(TFPreTrainedModel):
"""
@ -770,6 +892,14 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
class TFConvBertMaskedLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -814,6 +944,7 @@ class TFConvBertGeneratorPredictions(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")
self.config = config
def call(self, generator_hidden_states, training=False):
hidden_states = self.dense(generator_hidden_states)
@ -822,6 +953,17 @@ class TFConvBertGeneratorPredictions(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -901,6 +1043,20 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
attentions=generator_hidden_states.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
if getattr(self, "generator_predictions", None) is not None:
with tf.name_scope(self.generator_predictions.name):
self.generator_predictions.build(None)
if getattr(self, "generator_lm_head", None) is not None:
with tf.name_scope(self.generator_lm_head.name):
self.generator_lm_head.build(None)
class TFConvBertClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -931,6 +1087,17 @@ class TFConvBertClassificationHead(tf.keras.layers.Layer):
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -999,6 +1166,17 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1018,6 +1196,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1092,6 +1271,20 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1113,6 +1306,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1167,6 +1361,17 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1184,6 +1389,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1252,3 +1458,14 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convbert", None) is not None:
with tf.name_scope(self.convbert.name):
self.convbert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -81,6 +81,7 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
self.num_channels = config.num_channels
self.config = config
def call(self, pixel_values):
if isinstance(pixel_values, dict):
@ -101,6 +102,17 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
embeddings = self.layernorm(embeddings)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build([None, None, None, self.config.num_channels])
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
class TFConvNextLayer(tf.keras.layers.Layer):
"""This corresponds to the `Block` class in the original implementation.
@ -167,7 +179,25 @@ class TFConvNextLayer(tf.keras.layers.Layer):
if self.config.layer_scale_init_value > 0
else None
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dwconv", None) is not None:
with tf.name_scope(self.dwconv.name):
self.dwconv.build([None, None, None, self.dim])
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, None, self.dim])
if getattr(self, "pwconv1", None) is not None:
with tf.name_scope(self.pwconv1.name):
self.pwconv1.build([None, None, self.dim])
if getattr(self, "pwconv2", None) is not None:
with tf.name_scope(self.pwconv2.name):
self.pwconv2.build([None, None, 4 * self.dim])
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
def call(self, hidden_states, training=False):
input = hidden_states
@ -245,6 +275,9 @@ class TFConvNextStage(tf.keras.layers.Layer):
)
for j in range(depth)
]
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
def call(self, hidden_states):
for layer in self.downsampling_layer:
@ -253,6 +286,20 @@ class TFConvNextStage(tf.keras.layers.Layer):
hidden_states = layer(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
if self.in_channels != self.out_channels or self.stride > 1:
with tf.name_scope(self.downsampling_layer[0].name):
self.downsampling_layer[0].build([None, None, None, self.in_channels])
with tf.name_scope(self.downsampling_layer[1].name):
self.downsampling_layer[1].build([None, None, None, self.in_channels])
class TFConvNextEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -293,6 +340,11 @@ class TFConvNextEncoder(tf.keras.layers.Layer):
return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
def build(self, input_shape=None):
for stage in self.stages:
with tf.name_scope(stage.name):
stage.build(None)
@keras_serializable
class TFConvNextMainLayer(tf.keras.layers.Layer):
@ -353,6 +405,20 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, self.config.hidden_sizes[-1]])
class TFConvNextPreTrainedModel(TFPreTrainedModel):
"""
@ -485,6 +551,14 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convnext", None) is not None:
with tf.name_scope(self.convnext.name):
self.convnext.build(None)
@add_start_docstrings(
"""
@ -507,6 +581,7 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
bias_initializer="zeros",
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@ -577,3 +652,15 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
logits=logits,
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convnext", None) is not None:
with tf.name_scope(self.convnext.name):
self.convnext.build(None)
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_sizes[-1]])

View File

@ -133,6 +133,7 @@ class TFConvNextV2Embeddings(tf.keras.layers.Layer):
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
self.num_channels = config.num_channels
self.config = config
def call(self, pixel_values):
if isinstance(pixel_values, dict):
@ -153,6 +154,17 @@ class TFConvNextV2Embeddings(tf.keras.layers.Layer):
embeddings = self.layernorm(embeddings)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build([None, None, None, self.config.num_channels])
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
class TFConvNextV2Layer(tf.keras.layers.Layer):
"""This corresponds to the `Block` class in the original implementation.
@ -223,6 +235,29 @@ class TFConvNextV2Layer(tf.keras.layers.Layer):
x = input + x
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dwconv", None) is not None:
with tf.name_scope(self.dwconv.name):
self.dwconv.build([None, None, None, self.dim])
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, None, self.dim])
if getattr(self, "pwconv1", None) is not None:
with tf.name_scope(self.pwconv1.name):
self.pwconv1.build([None, None, self.dim])
if getattr(self, "grn", None) is not None:
with tf.name_scope(self.grn.name):
self.grn.build(None)
if getattr(self, "pwconv2", None) is not None:
with tf.name_scope(self.pwconv2.name):
self.pwconv2.build([None, None, 4 * self.dim])
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextStage with ConvNext->ConvNextV2
class TFConvNextV2Stage(tf.keras.layers.Layer):
@ -286,6 +321,9 @@ class TFConvNextV2Stage(tf.keras.layers.Layer):
)
for j in range(depth)
]
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
def call(self, hidden_states):
for layer in self.downsampling_layer:
@ -294,6 +332,20 @@ class TFConvNextV2Stage(tf.keras.layers.Layer):
hidden_states = layer(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
if self.in_channels != self.out_channels or self.stride > 1:
with tf.name_scope(self.downsampling_layer[0].name):
self.downsampling_layer[0].build([None, None, None, self.in_channels])
with tf.name_scope(self.downsampling_layer[1].name):
self.downsampling_layer[1].build([None, None, None, self.in_channels])
class TFConvNextV2Encoder(tf.keras.layers.Layer):
def __init__(self, config: ConvNextV2Config, **kwargs):
@ -339,6 +391,11 @@ class TFConvNextV2Encoder(tf.keras.layers.Layer):
return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
def build(self, input_shape=None):
for stage in self.stages:
with tf.name_scope(stage.name):
stage.build(None)
@keras_serializable
class TFConvNextV2MainLayer(tf.keras.layers.Layer):
@ -401,6 +458,20 @@ class TFConvNextV2MainLayer(tf.keras.layers.Layer):
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, self.config.hidden_sizes[-1]])
class TFConvNextV2PreTrainedModel(TFPreTrainedModel):
"""
@ -519,6 +590,14 @@ class TFConvNextV2Model(TFConvNextV2PreTrainedModel):
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convnextv2", None) is not None:
with tf.name_scope(self.convnextv2.name):
self.convnextv2.build(None)
@add_start_docstrings(
"""
@ -593,3 +672,14 @@ class TFConvNextV2ForImageClassification(TFConvNextV2PreTrainedModel, TFSequence
logits=logits,
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convnextv2", None) is not None:
with tf.name_scope(self.convnextv2.name):
self.convnextv2.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_sizes[-1]])

View File

@ -142,6 +142,23 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "Wq", None) is not None:
with tf.name_scope(self.Wq.name):
self.Wq.build([None, None, self.d_model_size])
if getattr(self, "Wk", None) is not None:
with tf.name_scope(self.Wk.name):
self.Wk.build([None, None, self.d_model_size])
if getattr(self, "Wv", None) is not None:
with tf.name_scope(self.Wv.name):
self.Wv.build([None, None, self.d_model_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.d_model_size])
class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
def __init__(self, d_model_size, dff, **kwargs):
@ -149,6 +166,8 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
self.dense_0 = tf.keras.layers.Dense(dff, activation="relu", name="0")
self.dense_2 = tf.keras.layers.Dense(d_model_size, name="2")
self.d_model_size = d_model_size
self.dff = dff
def call(self, inputs, trainable=False):
dense_0_output = self.dense_0(inputs)
@ -156,6 +175,17 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
return dense_2_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense_0", None) is not None:
with tf.name_scope(self.dense_0.name):
self.dense_0.build([None, None, self.d_model_size])
if getattr(self, "dense_2", None) is not None:
with tf.name_scope(self.dense_2.name):
self.dense_2.build([None, None, self.dff])
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(
@ -175,6 +205,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
self.d_model_size = d_model_size
def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
normed = self.layernorm1(x)
@ -202,6 +233,23 @@ class TFEncoderLayer(tf.keras.layers.Layer):
outputs = (out2,) + attn_outputs[1:]
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "multi_head_attention", None) is not None:
with tf.name_scope(self.multi_head_attention.name):
self.multi_head_attention.build(None)
if getattr(self, "ffn", None) is not None:
with tf.name_scope(self.ffn.name):
self.ffn.build(None)
if getattr(self, "layernorm1", None) is not None:
with tf.name_scope(self.layernorm1.name):
self.layernorm1.build([None, None, self.d_model_size])
if getattr(self, "layernorm2", None) is not None:
with tf.name_scope(self.layernorm2.name):
self.layernorm2.build([None, None, self.d_model_size])
@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
@ -396,6 +444,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
attentions=all_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "w", None) is not None:
with tf.name_scope(self.w.name):
self.w.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.n_embd])
if getattr(self, "h", None) is not None:
for layer in self.h:
with tf.name_scope(layer.name):
layer.build(None)
class TFCTRLPreTrainedModel(TFPreTrainedModel):
"""
@ -563,6 +626,14 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
class TFCTRLBiasLayer(tf.keras.layers.Layer):
"""
@ -710,6 +781,17 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)
@add_start_docstrings(
"""
@ -737,6 +819,7 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
use_bias=False,
)
self.transformer = TFCTRLMainLayer(config, name="transformer")
self.config = config
def get_output_embeddings(self):
# Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too.
@ -836,3 +919,14 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.n_embd])
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)

View File

@ -107,6 +107,7 @@ class TFCvtEmbeddings(tf.keras.layers.Layer):
self,
config: CvtConfig,
patch_size: int,
num_channels: int,
embed_dim: int,
stride: int,
padding: int,
@ -117,6 +118,7 @@ class TFCvtEmbeddings(tf.keras.layers.Layer):
self.convolution_embeddings = TFCvtConvEmbeddings(
config,
patch_size=patch_size,
num_channels=num_channels,
embed_dim=embed_dim,
stride=stride,
padding=padding,
@ -129,11 +131,28 @@ class TFCvtEmbeddings(tf.keras.layers.Layer):
hidden_state = self.dropout(hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution_embeddings", None) is not None:
with tf.name_scope(self.convolution_embeddings.name):
self.convolution_embeddings.build(None)
class TFCvtConvEmbeddings(tf.keras.layers.Layer):
"""Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts."""
def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs):
def __init__(
self,
config: CvtConfig,
patch_size: int,
num_channels: int,
embed_dim: int,
stride: int,
padding: int,
**kwargs,
):
super().__init__(**kwargs)
self.padding = tf.keras.layers.ZeroPadding2D(padding=padding)
self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
@ -148,6 +167,8 @@ class TFCvtConvEmbeddings(tf.keras.layers.Layer):
)
# Using the same default epsilon as PyTorch
self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")
self.num_channels = num_channels
self.embed_dim = embed_dim
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
if isinstance(pixel_values, dict):
@ -165,6 +186,17 @@ class TFCvtConvEmbeddings(tf.keras.layers.Layer):
pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
return pixel_values
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, self.embed_dim])
class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer):
"""Convolutional projection layer."""
@ -184,12 +216,24 @@ class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer):
)
# Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.embed_dim = embed_dim
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.convolution(self.padding(hidden_state))
hidden_state = self.normalization(hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution", None) is not None:
with tf.name_scope(self.convolution.name):
self.convolution.build([None, None, None, self.embed_dim])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.embed_dim])
class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer):
"""Linear projection layer used to flatten tokens into 1D."""
@ -227,6 +271,14 @@ class TFCvtSelfAttentionProjection(tf.keras.layers.Layer):
hidden_state = self.linear_projection(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution_projection", None) is not None:
with tf.name_scope(self.convolution_projection.name):
self.convolution_projection.build(None)
class TFCvtSelfAttention(tf.keras.layers.Layer):
"""
@ -348,6 +400,29 @@ class TFCvtSelfAttention(tf.keras.layers.Layer):
context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
return context
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution_projection_query", None) is not None:
with tf.name_scope(self.convolution_projection_query.name):
self.convolution_projection_query.build(None)
if getattr(self, "convolution_projection_key", None) is not None:
with tf.name_scope(self.convolution_projection_key.name):
self.convolution_projection_key.build(None)
if getattr(self, "convolution_projection_value", None) is not None:
with tf.name_scope(self.convolution_projection_value.name):
self.convolution_projection_value.build(None)
if getattr(self, "projection_query", None) is not None:
with tf.name_scope(self.projection_query.name):
self.projection_query.build([None, None, self.embed_dim])
if getattr(self, "projection_key", None) is not None:
with tf.name_scope(self.projection_key.name):
self.projection_key.build([None, None, self.embed_dim])
if getattr(self, "projection_value", None) is not None:
with tf.name_scope(self.projection_value.name):
self.projection_value.build([None, None, self.embed_dim])
class TFCvtSelfOutput(tf.keras.layers.Layer):
"""Output of the Attention layer ."""
@ -358,12 +433,21 @@ class TFCvtSelfOutput(tf.keras.layers.Layer):
units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(drop_rate)
self.embed_dim = embed_dim
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.dense(inputs=hidden_state)
hidden_state = self.dropout(inputs=hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.embed_dim])
class TFCvtAttention(tf.keras.layers.Layer):
"""Attention layer. First chunk of the convolutional transformer block."""
@ -411,6 +495,17 @@ class TFCvtAttention(tf.keras.layers.Layer):
attention_output = self.dense_output(self_output, training=training)
return attention_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFCvtIntermediate(tf.keras.layers.Layer):
"""Intermediate dense layer. Second chunk of the convolutional transformer block."""
@ -423,23 +518,34 @@ class TFCvtIntermediate(tf.keras.layers.Layer):
activation="gelu",
name="dense",
)
self.embed_dim = embed_dim
def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
hidden_state = self.dense(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.embed_dim])
class TFCvtOutput(tf.keras.layers.Layer):
"""
Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection.
"""
def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: int, **kwargs):
def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(drop_rate)
self.embed_dim = embed_dim
self.mlp_ratio = mlp_ratio
def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.dense(inputs=hidden_state)
@ -447,6 +553,14 @@ class TFCvtOutput(tf.keras.layers.Layer):
hidden_state = hidden_state + input_tensor
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)])
class TFCvtLayer(tf.keras.layers.Layer):
"""
@ -492,7 +606,7 @@ class TFCvtLayer(tf.keras.layers.Layer):
name="attention",
)
self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate")
self.dense_output = TFCvtOutput(config, embed_dim, drop_rate, name="output")
self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output")
# Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour.
self.drop_path = (
TFCvtDropPath(drop_path_rate, name="drop_path")
@ -502,6 +616,7 @@ class TFCvtLayer(tf.keras.layers.Layer):
# Using the same default epsilon as PyTorch
self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before")
self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after")
self.embed_dim = embed_dim
def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
# in Cvt, layernorm is applied before self-attention
@ -520,6 +635,29 @@ class TFCvtLayer(tf.keras.layers.Layer):
layer_output = self.drop_path(layer_output, training=training)
return layer_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.embed_dim])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.embed_dim])
class TFCvtStage(tf.keras.layers.Layer):
"""
@ -548,6 +686,7 @@ class TFCvtStage(tf.keras.layers.Layer):
self.embedding = TFCvtEmbeddings(
self.config,
patch_size=config.patch_sizes[self.stage],
num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
stride=config.patch_stride[self.stage],
embed_dim=config.embed_dim[self.stage],
padding=config.patch_padding[self.stage],
@ -603,6 +742,18 @@ class TFCvtStage(tf.keras.layers.Layer):
hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
return hidden_state, cls_token
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedding", None) is not None:
with tf.name_scope(self.embedding.name):
self.embedding.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFCvtEncoder(tf.keras.layers.Layer):
"""
@ -655,6 +806,15 @@ class TFCvtEncoder(tf.keras.layers.Layer):
hidden_states=all_hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "stages", None) is not None:
for layer in self.stages:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFCvtMainLayer(tf.keras.layers.Layer):
@ -696,6 +856,14 @@ class TFCvtMainLayer(tf.keras.layers.Layer):
hidden_states=encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
class TFCvtPreTrainedModel(TFPreTrainedModel):
"""
@ -815,6 +983,14 @@ class TFCvtModel(TFCvtPreTrainedModel):
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "cvt", None) is not None:
with tf.name_scope(self.cvt.name):
self.cvt.build(None)
@add_start_docstrings(
"""
@ -840,6 +1016,7 @@ class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassification
bias_initializer="zeros",
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
@ -909,3 +1086,18 @@ class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassification
return ((loss,) + output) if loss is not None else output
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "cvt", None) is not None:
with tf.name_scope(self.cvt.name):
self.cvt.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.embed_dim[-1]])
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.embed_dim[-1]])

View File

@ -137,7 +137,7 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
@ -164,7 +164,12 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
else:
self.position_embeddings = None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values)
@ -248,6 +253,14 @@ class TFData2VecVisionPatchEmbeddings(tf.keras.layers.Layer):
return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFData2VecVisionSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
@ -284,6 +297,7 @@ class TFData2VecVisionSelfAttention(tf.keras.layers.Layer):
)
else:
self.relative_position_bias = None
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -344,6 +358,23 @@ class TFData2VecVisionSelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "relative_position_bias", None) is not None:
with tf.name_scope(self.relative_position_bias.name):
self.relative_position_bias.build(None)
class TFData2VecVisionSelfOutput(tf.keras.layers.Layer):
"""
@ -358,6 +389,7 @@ class TFData2VecVisionSelfOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -365,6 +397,14 @@ class TFData2VecVisionSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFData2VecVisionAttention(tf.keras.layers.Layer):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
@ -398,6 +438,17 @@ class TFData2VecVisionAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
class TFData2VecVisionIntermediate(tf.keras.layers.Layer):
@ -412,6 +463,7 @@ class TFData2VecVisionIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -419,6 +471,14 @@ class TFData2VecVisionIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFData2VecVisionOutput(tf.keras.layers.Layer):
def __init__(self, config: Data2VecVisionConfig, **kwargs):
@ -428,6 +488,7 @@ class TFData2VecVisionOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -435,6 +496,14 @@ class TFData2VecVisionOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFData2VecVisionLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
@ -483,7 +552,27 @@ class TFData2VecVisionLayer(tf.keras.layers.Layer):
else:
self.lambda_1, self.lambda_2 = None, None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "data2vec_output", None) is not None:
with tf.name_scope(self.data2vec_output.name):
self.data2vec_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
def call(
self,
@ -650,6 +739,18 @@ class TFData2VecVisionEncoder(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "relative_position_bias", None) is not None:
with tf.name_scope(self.relative_position_bias.name):
self.relative_position_bias.build(None)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFData2VecVisionMainLayer(tf.keras.layers.Layer):
@ -741,6 +842,24 @@ class TFData2VecVisionMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
if hasattr(self.layernorm, "name"):
with tf.name_scope(self.layernorm.name):
self.layernorm.build((None, self.config.hidden_size))
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFData2VecVisionPooler(tf.keras.layers.Layer):
def __init__(self, config: Data2VecVisionConfig, **kwargs):
@ -750,6 +869,7 @@ class TFData2VecVisionPooler(tf.keras.layers.Layer):
if config.use_mean_pooling
else None
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
if self.layernorm is not None:
@ -762,6 +882,15 @@ class TFData2VecVisionPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layernorm", None) is not None:
if hasattr(self.layernorm, "name"):
with tf.name_scope(self.layernorm.name):
self.layernorm.build((None, self.config.hidden_size))
class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
"""
@ -896,6 +1025,14 @@ class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "data2vec_vision", None) is not None:
with tf.name_scope(self.data2vec_vision.name):
self.data2vec_vision.build(None)
@add_start_docstrings(
"""
@ -917,6 +1054,7 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
@ -968,6 +1106,17 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "data2vec_vision", None) is not None:
with tf.name_scope(self.data2vec_vision.name):
self.data2vec_vision.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
class TFData2VecVisionConvModule(tf.keras.layers.Layer):
"""
@ -979,6 +1128,7 @@ class TFData2VecVisionConvModule(tf.keras.layers.Layer):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
padding: str = "valid",
@ -997,6 +1147,8 @@ class TFData2VecVisionConvModule(tf.keras.layers.Layer):
)
self.bn = tf.keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
self.activation = tf.nn.relu
self.in_channels = in_channels
self.out_channels = out_channels
def call(self, input: tf.Tensor) -> tf.Tensor:
output = self.conv(input)
@ -1004,88 +1156,140 @@ class TFData2VecVisionConvModule(tf.keras.layers.Layer):
output = self.activation(output)
return output
# Copied from:
# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb
class TFAdaptiveAvgPool1D(tf.keras.layers.Layer):
def __init__(self, output_dim, mode="dense", **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.mode = mode
self.map = None
def build(self, input_shape):
super().build(input_shape)
"""We pre-compute the sparse matrix for the build() step once. The below code comes
from https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993."""
def get_kernels(ind, outd) -> List:
"""Returns a List [(kernel_offset_start,kernel_length)] defining all the pooling kernels for a 1-D adaptive
pooling layer that takes an input of dimension `ind` and yields an output of dimension `outd`"""
def start_index(a, b, c):
return math.floor((float(a) * float(c)) / b)
def end_index(a, b, c):
return math.ceil((float(a + 1) * float(c)) / b)
results = []
for ow in range(outd):
start = start_index(ow, outd, ind)
end = end_index(ow, outd, ind)
sz = end - start
results.append((start, sz))
return results
in_dim = int(input_shape[-1])
kernels = get_kernels(in_dim, self.output_dim)
sparse_map = np.zeros((in_dim, self.output_dim), dtype=np.float32)
for i, kernel in enumerate(kernels):
sparse_map[kernel[0] : kernel[0] + kernel[1], i] = 1 / kernel[1]
if self.mode == "dense":
self.map = tf.constant(sparse_map)
else:
self.map = tf.sparse.from_dense(sparse_map)
def call(self, inputs):
if self.mode == "dense":
return inputs @ self.map
else:
input_dims = inputs.shape
input_matrix = tf.reshape(inputs, (-1, input_dims[-1]))
out = tf.sparse.sparse_dense_matmul(input_matrix, self.map)
return tf.reshape(out, input_dims[:-1].as_list() + [-1])
def get_config(self):
config = super().get_config()
config.update({"output_dim": self.output_dim, "mode": self.mode})
return config
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, None, self.in_channels])
if getattr(self, "bn", None) is not None:
with tf.name_scope(self.bn.name):
self.bn.build((None, None, None, self.out_channels))
class TFAdaptiveAvgPool2D(tf.keras.layers.Layer):
def __init__(self, output_shape, mode="dense", **kwargs):
def __init__(self, output_dims: Tuple[int, int], input_ordering: str = "NHWC", **kwargs):
super().__init__(**kwargs)
self.mode = mode
self.h_pool = TFAdaptiveAvgPool1D(output_shape[0], mode=mode, name="h_pool")
self.w_pool = TFAdaptiveAvgPool1D(output_shape[1], mode=mode, name="w_pool")
self.output_dims = output_dims
self.input_ordering = input_ordering
if input_ordering not in ("NCHW", "NHWC"):
raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!")
self.h_axis = input_ordering.index("H")
self.w_axis = input_ordering.index("W")
def call(self, inputs):
# Rearrange from NHWC -> NCHW
inputs = tf.transpose(inputs, perm=[0, 3, 1, 2])
# Perform W-pooling
inputs = self.w_pool(inputs)
# Rearrange NCHW -> NCWH
inputs = tf.transpose(inputs, perm=[0, 1, 3, 2])
# Perform H-pooling
inputs = self.h_pool(inputs)
# Rearrange from NCWH -> NHWC
inputs = tf.transpose(inputs, perm=[0, 3, 2, 1])
return inputs
def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
# Figure out which axis we're pooling on
if h_pooling:
axis = self.h_axis
output_dim = self.output_dims[0]
else:
axis = self.w_axis
output_dim = self.output_dims[1]
input_dim = inputs.shape[axis]
def get_config(self):
config = super().get_config()
config.update({"mode": self.mode})
return config
# Figure out the potential pooling windows
# This is the key idea - the torch op always uses only two
# consecutive pooling window sizes, like 3 and 4. Therefore,
# if we pool with both possible sizes, we simply need to gather
# the 'correct' pool at each position to reimplement the torch op.
small_window = math.ceil(input_dim / output_dim)
big_window = small_window + 1
if h_pooling:
output_dim = self.output_dims[0]
small_window_shape = (small_window, 1)
big_window_shape = (big_window, 1)
else:
output_dim = self.output_dims[1]
small_window_shape = (1, small_window)
big_window_shape = (1, big_window)
# For resizes to 1, or integer resizes, we can take quick shortcuts
if output_dim == input_dim:
return inputs
elif output_dim == 1:
return tf.reduce_mean(inputs, axis=axis, keepdims=True)
elif input_dim % output_dim == 0:
return tf.nn.avg_pool2d(
inputs,
ksize=small_window_shape,
strides=small_window_shape,
padding="VALID",
data_format=self.input_ordering,
)
# When upscaling by an integer factor we can also take a quick shortcut
elif output_dim > input_dim and output_dim % input_dim == 0:
return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis)
# For non-integer resizes, we pool with both possible window sizes and concatenate them
if output_dim < input_dim:
small_pool = tf.nn.avg_pool2d(
inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
)
big_pool = tf.nn.avg_pool2d(
inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
)
both_pool = tf.concat([small_pool, big_pool], axis=axis)
else:
# When we're actually upscaling instead, then we build the pools a bit differently
small_pool = inputs
big_pool = tf.nn.avg_pool2d(
inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
)
both_pool = tf.concat([small_pool, big_pool], axis=axis)
# We compute vectors of the start and end positions for each pooling window
# Each (start, end) pair here corresponds to a single output position
window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim)
window_starts = tf.cast(window_starts, tf.int64)
window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim)
window_ends = tf.cast(window_ends, tf.int64)
# pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position
# has a big receptive field and 0 indicates that that output position has a small receptive field
pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool)
# Since we concatenated the small and big pools, we need to do a bit of
# pointer arithmetic to get the indices of the big pools
small_indices = window_starts
big_indices = window_starts + small_pool.shape[axis]
# Finally, we use the pool_selector to generate a list of indices, one per output position
gather_indices = tf.where(pool_selector, big_indices, small_indices)
# Gathering from those indices yields the final, correct pooling
return tf.gather(both_pool, gather_indices, axis=axis)
def call(self, inputs: tf.Tensor):
if self.input_ordering == "NHWC":
input_shape = inputs.shape[1:3]
else:
input_shape = inputs.shape[2:]
# We break the task down into each possible case
# Firstly, if we're resizing down to 1, it's just tf.reduce_mean
if self.output_dims[0] == self.output_dims[1] == 1:
if self.input_ordering == "NHWC":
reduce_dims = [1, 2]
else:
reduce_dims = [2, 3]
return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True)
# Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut
elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0:
h_resize = int(input_shape[0] // self.output_dims[0])
w_resize = int(input_shape[1] // self.output_dims[1])
return tf.nn.avg_pool2d(
inputs,
ksize=(h_resize, w_resize),
strides=(h_resize, w_resize),
padding="VALID",
data_format=self.input_ordering,
)
else:
# Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut
# for dimensions where an integer resize is possible. It can also handle upscaling.
h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
return self.pseudo_1d_pool(h_pooled, h_pooling=False)
class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
@ -1100,18 +1304,21 @@ class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, pool_scales: Tuple[int, ...], channels: int, **kwargs) -> None:
def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None:
super().__init__(**kwargs)
self.pool_scales = pool_scales
self.channels = channels
self.in_channels = in_channels
self.out_channels = out_channels
self.layer_list = []
for idx, pool_scale in enumerate(pool_scales):
pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
self.layer_list.append(
[
TFAdaptiveAvgPool2D(output_shape=pool_scale),
TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"{idx}.1"),
TFAdaptiveAvgPool2D(output_dims=pool_scale),
TFData2VecVisionConvModule(
in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1"
),
]
)
@ -1128,6 +1335,12 @@ class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
def build(self, input_shape=None):
for layer in self.layer_list:
for layer_module in layer:
with tf.name_scope(layer_module.name):
layer_module.build(None)
class TFData2VecVisionUperHead(tf.keras.layers.Layer):
"""
@ -1146,21 +1359,39 @@ class TFData2VecVisionUperHead(tf.keras.layers.Layer):
self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
# PSP Module
self.psp_modules = TFData2VecVisionPyramidPoolingModule(self.pool_scales, self.channels, name="psp_modules")
self.bottleneck = TFData2VecVisionConvModule(self.channels, kernel_size=3, padding="same", name="bottleneck")
self.psp_modules = TFData2VecVisionPyramidPoolingModule(
self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules"
)
self.bottleneck = TFData2VecVisionConvModule(
self.in_channels[-1] + len(self.pool_scales) * self.channels,
self.channels,
kernel_size=3,
padding="same",
name="bottleneck",
)
# FPN Module
self.lateral_convs = []
self.fpn_convs = []
for idx, _ in enumerate(self.in_channels[:-1]): # skip the top layer
l_conv = TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}")
for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer
l_conv = TFData2VecVisionConvModule(
in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}"
)
fpn_conv = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=3, padding="same", name=f"fpn_convs.{idx}"
in_channels=self.channels,
out_channels=self.channels,
kernel_size=3,
padding="same",
name=f"fpn_convs.{idx}",
)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=3, padding="same", name="fpn_bottleneck"
in_channels=len(self.in_channels) * self.channels,
out_channels=self.channels,
kernel_size=3,
padding="same",
name="fpn_bottleneck",
)
def psp_forward(self, inputs):
@ -1197,6 +1428,29 @@ class TFData2VecVisionUperHead(tf.keras.layers.Layer):
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, None, self.channels])
if getattr(self, "psp_modules", None) is not None:
with tf.name_scope(self.psp_modules.name):
self.psp_modules.build(None)
if getattr(self, "bottleneck", None) is not None:
with tf.name_scope(self.bottleneck.name):
self.bottleneck.build(None)
if getattr(self, "fpn_bottleneck", None) is not None:
with tf.name_scope(self.fpn_bottleneck.name):
self.fpn_bottleneck.build(None)
for layer in self.lateral_convs:
with tf.name_scope(layer.name):
layer.build(None)
for layer in self.fpn_convs:
with tf.name_scope(layer.name):
layer.build(None)
class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
"""
@ -1230,6 +1484,7 @@ class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
convs = []
convs.append(
TFData2VecVisionConvModule(
in_channels=self.in_channels,
out_channels=self.channels,
kernel_size=kernel_size,
padding="same",
@ -1240,6 +1495,7 @@ class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
for i in range(self.num_convs - 1):
convs.append(
TFData2VecVisionConvModule(
in_channels=self.channels,
out_channels=self.channels,
kernel_size=kernel_size,
padding="same",
@ -1253,7 +1509,11 @@ class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
self.convs = convs
if self.concat_input:
self.conv_cat = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=kernel_size, padding="same", name="conv_cat"
self.in_channels + self.channels,
out_channels=self.channels,
kernel_size=kernel_size,
padding="same",
name="conv_cat",
)
self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
@ -1269,6 +1529,17 @@ class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
output = self.classifier(output)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, None, self.channels])
if getattr(self, "conv_cat", None) is not None:
with tf.name_scope(self.conv_cat.name):
self.conv_cat.build(None)
@add_start_docstrings(
"""
@ -1428,3 +1699,27 @@ class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "data2vec_vision", None) is not None:
with tf.name_scope(self.data2vec_vision.name):
self.data2vec_vision.build(None)
if getattr(self, "decode_head", None) is not None:
with tf.name_scope(self.decode_head.name):
self.decode_head.build(None)
if getattr(self, "auxiliary_head", None) is not None:
with tf.name_scope(self.auxiliary_head.name):
self.auxiliary_head.build(None)
if getattr(self, "fpn1", None) is not None:
with tf.name_scope(self.fpn1[0].name):
self.fpn1[0].build([None, None, None, self.config.hidden_size])
with tf.name_scope(self.fpn1[1].name):
self.fpn1[1].build((None, None, None, self.config.hidden_size))
with tf.name_scope(self.fpn1[3].name):
self.fpn1[3].build([None, None, None, self.config.hidden_size])
if getattr(self, "fpn2", None) is not None:
with tf.name_scope(self.fpn2[0].name):
self.fpn2[0].build([None, None, None, self.config.hidden_size])

View File

@ -78,6 +78,17 @@ class TFDebertaContextPooler(tf.keras.layers.Layer):
def output_dim(self) -> int:
return self.config.hidden_size
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.pooler_hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
class TFDebertaXSoftmax(tf.keras.layers.Layer):
"""
@ -167,6 +178,7 @@ class TFDebertaSelfOutput(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def call(self, hidden_states, input_tensor, training: bool = False):
hidden_states = self.dense(hidden_states)
@ -174,6 +186,20 @@ class TFDebertaSelfOutput(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
class TFDebertaAttention(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, **kwargs):
@ -211,6 +237,17 @@ class TFDebertaAttention(tf.keras.layers.Layer):
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFDebertaIntermediate(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, **kwargs):
@ -224,6 +261,7 @@ class TFDebertaIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -231,6 +269,14 @@ class TFDebertaIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFDebertaOutput(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, **kwargs):
@ -241,6 +287,7 @@ class TFDebertaOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -249,6 +296,20 @@ class TFDebertaOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
class TFDebertaLayer(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, **kwargs):
@ -286,6 +347,20 @@ class TFDebertaLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
class TFDebertaEncoder(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, **kwargs):
@ -299,14 +374,20 @@ class TFDebertaEncoder(tf.keras.layers.Layer):
if self.max_relative_positions < 1:
self.max_relative_positions = config.max_position_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.relative_attention:
self.rel_embeddings = self.add_weight(
name="rel_embeddings.weight",
shape=[self.max_relative_positions * 2, self.config.hidden_size],
initializer=get_initializer(self.config.initializer_range),
)
return super().build(input_shape)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings if self.relative_attention else None
@ -528,15 +609,39 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
)
self.dropout = TFDebertaStableDropout(config.attention_probs_dropout_prob, name="dropout")
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
self.q_bias = self.add_weight(
name="q_bias", shape=(self.all_head_size), initializer=tf.keras.initializers.Zeros()
)
self.v_bias = self.add_weight(
name="v_bias", shape=(self.all_head_size), initializer=tf.keras.initializers.Zeros()
)
return super().build(input_shape)
if getattr(self, "in_proj", None) is not None:
with tf.name_scope(self.in_proj.name):
self.in_proj.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "head_logits_proj", None) is not None:
with tf.name_scope(self.head_logits_proj.name):
self.head_logits_proj.build(None)
if getattr(self, "head_weights_proj", None) is not None:
with tf.name_scope(self.head_weights_proj.name):
self.head_weights_proj.build(None)
if getattr(self, "pos_dropout", None) is not None:
with tf.name_scope(self.pos_dropout.name):
self.pos_dropout.build(None)
if getattr(self, "pos_proj", None) is not None:
with tf.name_scope(self.pos_proj.name):
self.pos_proj.build(None)
if getattr(self, "pos_q_proj", None) is not None:
with tf.name_scope(self.pos_q_proj.name):
self.pos_q_proj.build(None)
def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1]
@ -735,7 +840,7 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -763,7 +868,18 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
else:
self.position_embeddings = None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "embed_proj", None) is not None:
with tf.name_scope(self.embed_proj.name):
self.embed_proj.build([None, None, self.embedding_size])
def call(
self,
@ -838,6 +954,7 @@ class TFDebertaPredictionHeadTransform(tf.keras.layers.Layer):
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -846,6 +963,17 @@ class TFDebertaPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.embedding_size])
class TFDebertaLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config: DebertaConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
@ -860,10 +988,15 @@ class TFDebertaLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -900,6 +1033,14 @@ class TFDebertaOnlyMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
# @keras_serializable
class TFDebertaMainLayer(tf.keras.layers.Layer):
@ -984,6 +1125,17 @@ class TFDebertaMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
class TFDebertaPreTrainedModel(TFPreTrainedModel):
"""
@ -1124,6 +1276,14 @@ class TFDebertaModel(TFDebertaPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1194,6 +1354,17 @@ class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""
@ -1219,6 +1390,7 @@ class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceCla
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.output_dim = self.pooler.output_dim
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1275,6 +1447,23 @@ class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.output_dim])
@add_start_docstrings(
"""
@ -1294,6 +1483,7 @@ class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassific
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1346,6 +1536,17 @@ class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassific
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1364,6 +1565,7 @@ class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnswerin
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1430,3 +1632,14 @@ class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnswerin
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -78,6 +78,17 @@ class TFDebertaV2ContextPooler(tf.keras.layers.Layer):
def output_dim(self) -> int:
return self.config.hidden_size
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.pooler_hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2
class TFDebertaV2XSoftmax(tf.keras.layers.Layer):
@ -150,6 +161,7 @@ class TFDebertaV2SelfOutput(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def call(self, hidden_states, input_tensor, training: bool = False):
hidden_states = self.dense(hidden_states)
@ -157,6 +169,20 @@ class TFDebertaV2SelfOutput(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2
class TFDebertaV2Attention(tf.keras.layers.Layer):
@ -195,6 +221,17 @@ class TFDebertaV2Attention(tf.keras.layers.Layer):
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate with Deberta->DebertaV2
class TFDebertaV2Intermediate(tf.keras.layers.Layer):
@ -209,6 +246,7 @@ class TFDebertaV2Intermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -216,6 +254,14 @@ class TFDebertaV2Intermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2
class TFDebertaV2Output(tf.keras.layers.Layer):
@ -227,6 +273,7 @@ class TFDebertaV2Output(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -235,6 +282,20 @@ class TFDebertaV2Output(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2
class TFDebertaV2Layer(tf.keras.layers.Layer):
@ -273,6 +334,20 @@ class TFDebertaV2Layer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
def __init__(self, config: DebertaV2Config, **kwargs):
@ -286,7 +361,7 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("conv"):
self.conv_kernel = self.add_weight(
name="kernel",
@ -296,7 +371,16 @@ class TFDebertaV2ConvLayer(tf.keras.layers.Layer):
self.conv_bias = self.add_weight(
name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer()
)
return super().build(input_shape)
return
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
def call(
self, hidden_states: tf.Tensor, residual_states: tf.Tensor, input_mask: tf.Tensor, training: bool = False
@ -354,14 +438,26 @@ class TFDebertaV2Encoder(tf.keras.layers.Layer):
self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.relative_attention:
self.rel_embeddings = self.add_weight(
name="rel_embeddings.weight",
shape=[self.pos_ebd_size, self.config.hidden_size],
initializer=get_initializer(self.config.initializer_range),
)
return super().build(input_shape)
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings if self.relative_attention else None
@ -611,6 +707,7 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
)
self.softmax = TFDebertaV2XSoftmax(axis=-1)
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
tensor_shape = shape_list(tensor)
@ -801,6 +898,32 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
return score
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query_proj", None) is not None:
with tf.name_scope(self.query_proj.name):
self.query_proj.build([None, None, self.config.hidden_size])
if getattr(self, "key_proj", None) is not None:
with tf.name_scope(self.key_proj.name):
self.key_proj.build([None, None, self.config.hidden_size])
if getattr(self, "value_proj", None) is not None:
with tf.name_scope(self.value_proj.name):
self.value_proj.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "pos_dropout", None) is not None:
with tf.name_scope(self.pos_dropout.name):
self.pos_dropout.build(None)
if getattr(self, "pos_key_proj", None) is not None:
with tf.name_scope(self.pos_key_proj.name):
self.pos_key_proj.build([None, None, self.config.hidden_size])
if getattr(self, "pos_query_proj", None) is not None:
with tf.name_scope(self.pos_query_proj.name):
self.pos_query_proj.build([None, None, self.config.hidden_size])
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2
class TFDebertaV2Embeddings(tf.keras.layers.Layer):
@ -825,7 +948,7 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -853,7 +976,18 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
else:
self.position_embeddings = None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "embed_proj", None) is not None:
with tf.name_scope(self.embed_proj.name):
self.embed_proj.build([None, None, self.embedding_size])
def call(
self,
@ -929,6 +1063,7 @@ class TFDebertaV2PredictionHeadTransform(tf.keras.layers.Layer):
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -937,6 +1072,17 @@ class TFDebertaV2PredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.embedding_size])
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLMPredictionHead with Deberta->DebertaV2
class TFDebertaV2LMPredictionHead(tf.keras.layers.Layer):
@ -952,10 +1098,15 @@ class TFDebertaV2LMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -993,6 +1144,14 @@ class TFDebertaV2OnlyMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaMainLayer with Deberta->DebertaV2
class TFDebertaV2MainLayer(tf.keras.layers.Layer):
@ -1077,6 +1236,17 @@ class TFDebertaV2MainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPreTrainedModel with Deberta->DebertaV2
class TFDebertaV2PreTrainedModel(TFPreTrainedModel):
@ -1219,6 +1389,14 @@ class TFDebertaV2Model(TFDebertaV2PreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForMaskedLM with Deberta->DebertaV2
@ -1290,6 +1468,17 @@ class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelin
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""
@ -1316,6 +1505,7 @@ class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenc
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.output_dim = self.pooler.output_dim
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1372,6 +1562,23 @@ class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenc
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.output_dim])
@add_start_docstrings(
"""
@ -1392,6 +1599,7 @@ class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClass
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1444,6 +1652,17 @@ class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClass
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1463,6 +1682,7 @@ class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsw
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1530,6 +1750,17 @@ class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsw
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1552,6 +1783,7 @@ class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceL
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.output_dim = self.pooler.output_dim
@unpack_inputs
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1628,3 +1860,17 @@ class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceL
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deberta", None) is not None:
with tf.name_scope(self.deberta.name):
self.deberta.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.output_dim])

View File

@ -113,7 +113,7 @@ class TFDeiTEmbeddings(tf.keras.layers.Layer):
self.patch_embeddings = TFDeiTPatchEmbeddings(config=config, name="patch_embeddings")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=tf.keras.initializers.zeros(),
@ -141,7 +141,16 @@ class TFDeiTEmbeddings(tf.keras.layers.Layer):
trainable=True,
name="position_embeddings",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
def call(
self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False
@ -203,6 +212,14 @@ class TFDeiTPatchEmbeddings(tf.keras.layers.Layer):
x = tf.reshape(x, (batch_size, height * width, num_channels))
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->DeiT
class TFDeiTSelfAttention(tf.keras.layers.Layer):
@ -230,6 +247,7 @@ class TFDeiTSelfAttention(tf.keras.layers.Layer):
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -279,6 +297,20 @@ class TFDeiTSelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->DeiT
class TFDeiTSelfOutput(tf.keras.layers.Layer):
@ -294,6 +326,7 @@ class TFDeiTSelfOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -301,6 +334,14 @@ class TFDeiTSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->DeiT
class TFDeiTAttention(tf.keras.layers.Layer):
@ -330,6 +371,17 @@ class TFDeiTAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->DeiT
class TFDeiTIntermediate(tf.keras.layers.Layer):
@ -344,6 +396,7 @@ class TFDeiTIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -351,6 +404,14 @@ class TFDeiTIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->DeiT
class TFDeiTOutput(tf.keras.layers.Layer):
@ -361,6 +422,7 @@ class TFDeiTOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -369,6 +431,14 @@ class TFDeiTOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFDeiTLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
@ -386,6 +456,7 @@ class TFDeiTLayer(tf.keras.layers.Layer):
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
self.config = config
def call(
self,
@ -419,6 +490,26 @@ class TFDeiTLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "deit_output", None) is not None:
with tf.name_scope(self.deit_output.name):
self.deit_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->DeiT
class TFDeiTEncoder(tf.keras.layers.Layer):
@ -465,6 +556,15 @@ class TFDeiTEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFDeiTMainLayer(tf.keras.layers.Layer):
@ -556,6 +656,23 @@ class TFDeiTMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTPreTrainedModel with ViT->DeiT all-casing
class TFDeiTPreTrainedModel(TFPreTrainedModel):
@ -647,6 +764,14 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deit", None) is not None:
with tf.name_scope(self.deit.name):
self.deit.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTPooler with ViT->DeiT
class TFDeiTPooler(tf.keras.layers.Layer):
@ -659,6 +784,7 @@ class TFDeiTPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -668,6 +794,14 @@ class TFDeiTPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFDeitPixelShuffle(tf.keras.layers.Layer):
"""TF layer implementation of torch.nn.PixelShuffle"""
@ -702,6 +836,7 @@ class TFDeitDecoder(tf.keras.layers.Layer):
filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0"
)
self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1")
self.config = config
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = inputs
@ -709,6 +844,17 @@ class TFDeitDecoder(tf.keras.layers.Layer):
hidden_states = self.pixel_shuffle(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv2d", None) is not None:
with tf.name_scope(self.conv2d.name):
self.conv2d.build([None, None, None, self.config.hidden_size])
if getattr(self, "pixel_shuffle", None) is not None:
with tf.name_scope(self.pixel_shuffle.name):
self.pixel_shuffle.build(None)
@add_start_docstrings(
"DeiT Model with a decoder on top for masked image modeling, as proposed in"
@ -822,6 +968,17 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deit", None) is not None:
with tf.name_scope(self.deit.name):
self.deit.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"""
@ -843,6 +1000,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="classifier")
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@ -919,6 +1077,17 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deit", None) is not None:
with tf.name_scope(self.deit.name):
self.deit.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -950,6 +1119,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="distillation_classifier")
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@ -998,3 +1168,17 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "deit", None) is not None:
with tf.name_scope(self.deit.name):
self.deit.build(None)
if getattr(self, "cls_classifier", None) is not None:
with tf.name_scope(self.cls_classifier.name):
self.cls_classifier.build([None, None, self.config.hidden_size])
if getattr(self, "distillation_classifier", None) is not None:
with tf.name_scope(self.distillation_classifier.name):
self.distillation_classifier.build([None, None, self.config.hidden_size])

View File

@ -84,7 +84,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.dropout)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -99,7 +99,12 @@ class TFEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.dim])
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
"""
@ -152,6 +157,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
)
self.pruned_heads = set()
self.config = config
def prune_heads(self, heads):
raise NotImplementedError
@ -212,6 +218,23 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
else:
return (context,)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_lin", None) is not None:
with tf.name_scope(self.q_lin.name):
self.q_lin.build([None, None, self.config.dim])
if getattr(self, "k_lin", None) is not None:
with tf.name_scope(self.k_lin.name):
self.k_lin.build([None, None, self.config.dim])
if getattr(self, "v_lin", None) is not None:
with tf.name_scope(self.v_lin.name):
self.v_lin.build([None, None, self.config.dim])
if getattr(self, "out_lin", None) is not None:
with tf.name_scope(self.out_lin.name):
self.out_lin.build([None, None, self.config.dim])
class TFFFN(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -224,6 +247,7 @@ class TFFFN(tf.keras.layers.Layer):
config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
)
self.activation = get_tf_activation(config.activation)
self.config = config
def call(self, input, training=False):
x = self.lin1(input)
@ -232,6 +256,17 @@ class TFFFN(tf.keras.layers.Layer):
x = self.dropout(x, training=training)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lin1", None) is not None:
with tf.name_scope(self.lin1.name):
self.lin1.build([None, None, self.config.dim])
if getattr(self, "lin2", None) is not None:
with tf.name_scope(self.lin2.name):
self.lin2.build([None, None, self.config.hidden_dim])
class TFTransformerBlock(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -253,6 +288,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
self.ffn = TFFFN(config, name="ffn")
self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
self.config = config
def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None
"""
@ -281,6 +317,23 @@ class TFTransformerBlock(tf.keras.layers.Layer):
output = (sa_weights,) + output
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "sa_layer_norm", None) is not None:
with tf.name_scope(self.sa_layer_norm.name):
self.sa_layer_norm.build([None, None, self.config.dim])
if getattr(self, "ffn", None) is not None:
with tf.name_scope(self.ffn.name):
self.ffn.build(None)
if getattr(self, "output_layer_norm", None) is not None:
with tf.name_scope(self.output_layer_norm.name):
self.output_layer_norm.build([None, None, self.config.dim])
class TFTransformer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -336,6 +389,15 @@ class TFTransformer(tf.keras.layers.Layer):
last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFDistilBertMainLayer(tf.keras.layers.Layer):
@ -412,6 +474,17 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
@ -548,6 +621,14 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
class TFDistilBertLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -667,6 +748,23 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
attentions=distilbert_output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
if getattr(self, "vocab_transform", None) is not None:
with tf.name_scope(self.vocab_transform.name):
self.vocab_transform.build([None, None, self.config.dim])
if getattr(self, "vocab_layer_norm", None) is not None:
with tf.name_scope(self.vocab_layer_norm.name):
self.vocab_layer_norm.build([None, None, self.config.dim])
if getattr(self, "vocab_projector", None) is not None:
with tf.name_scope(self.vocab_projector.name):
self.vocab_projector.build(None)
@add_start_docstrings(
"""
@ -691,6 +789,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -746,6 +845,20 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
attentions=distilbert_output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
if getattr(self, "pre_classifier", None) is not None:
with tf.name_scope(self.pre_classifier.name):
self.pre_classifier.build([None, None, self.config.dim])
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.dim])
@add_start_docstrings(
"""
@ -764,6 +877,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -814,6 +928,17 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -837,6 +962,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -908,6 +1034,20 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
attentions=distilbert_output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
if getattr(self, "pre_classifier", None) is not None:
with tf.name_scope(self.pre_classifier.name):
self.pre_classifier.build([None, None, self.config.dim])
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.dim])
@add_start_docstrings(
"""
@ -926,6 +1066,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
)
assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2"
self.dropout = tf.keras.layers.Dropout(config.qa_dropout)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -991,3 +1132,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
hidden_states=distilbert_output.hidden_states,
attentions=distilbert_output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "distilbert", None) is not None:
with tf.name_scope(self.distilbert.name):
self.distilbert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.dim])

View File

@ -209,6 +209,17 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
return self.projection_dim
return self.bert_model.config.hidden_size
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bert_model", None) is not None:
with tf.name_scope(self.bert_model.name):
self.bert_model.build(None)
if getattr(self, "encode_proj", None) is not None:
with tf.name_scope(self.encode_proj.name):
self.encode_proj.build(None)
class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
base_model_prefix = "encoder"
@ -273,6 +284,20 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.encoder.embeddings_size])
if getattr(self, "qa_classifier", None) is not None:
with tf.name_scope(self.qa_classifier.name):
self.qa_classifier.build([None, None, self.encoder.embeddings_size])
class TFDPRSpanPredictor(TFPreTrainedModel):
base_model_prefix = "encoder"
@ -599,6 +624,14 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "ctx_encoder", None) is not None:
with tf.name_scope(self.ctx_encoder.name):
self.ctx_encoder.build(None)
@add_start_docstrings(
"The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
@ -679,6 +712,14 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "question_encoder", None) is not None:
with tf.name_scope(self.question_encoder.name):
self.question_encoder.build(None)
@add_start_docstrings(
"The bare DPRReader transformer outputting span predictions.",
@ -752,3 +793,11 @@ class TFDPRReader(TFDPRPretrainedReader):
return_dict=return_dict,
training=training,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "span_predictor", None) is not None:
with tf.name_scope(self.span_predictor.name):
self.span_predictor.build(None)

View File

@ -90,6 +90,7 @@ class TFEfficientFormerPatchEmbeddings(tf.keras.layers.Layer):
if apply_norm
else tf.identity
)
self.embed_dim = embed_dim
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
tf.debugging.assert_shapes(
@ -100,6 +101,18 @@ class TFEfficientFormerPatchEmbeddings(tf.keras.layers.Layer):
embeddings = self.norm(embeddings, training=training)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
if getattr(self, "norm", None) is not None:
if hasattr(self.norm, "name"):
with tf.name_scope(self.norm.name):
self.norm.build([None, None, None, self.embed_dim])
class TFEfficientFormerSelfAttention(tf.keras.layers.Layer):
def __init__(
@ -130,6 +143,7 @@ class TFEfficientFormerSelfAttention(tf.keras.layers.Layer):
units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection"
)
self.resolution = resolution
self.dim = dim
def build(self, input_shape: tf.TensorShape) -> None:
points = list(itertools.product(range(self.resolution), range(self.resolution)))
@ -160,7 +174,15 @@ class TFEfficientFormerSelfAttention(tf.keras.layers.Layer):
self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points)))
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "qkv", None) is not None:
with tf.name_scope(self.qkv.name):
self.qkv.build([None, None, self.dim])
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, self.total_expanded_key_dim])
def call(
self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
@ -225,6 +247,8 @@ class TFEfficientFormerConvStem(tf.keras.layers.Layer):
)
self.activation = tf.keras.layers.Activation(activation=tf.keras.activations.relu, name="activation")
self.out_channels = out_channels
self.config = config
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
features = self.batchnorm_before(self.convolution1(self.padding(pixel_values)), training=training)
@ -233,6 +257,26 @@ class TFEfficientFormerConvStem(tf.keras.layers.Layer):
features = self.activation(features)
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution1", None) is not None:
with tf.name_scope(self.convolution1.name):
self.convolution1.build([None, None, None, self.config.num_channels])
if getattr(self, "batchnorm_before", None) is not None:
with tf.name_scope(self.batchnorm_before.name):
self.batchnorm_before.build([None, None, None, self.out_channels // 2])
if getattr(self, "convolution2", None) is not None:
with tf.name_scope(self.convolution2.name):
self.convolution2.build([None, None, None, self.out_channels // 2])
if getattr(self, "batchnorm_after", None) is not None:
with tf.name_scope(self.batchnorm_after.name):
self.batchnorm_after.build([None, None, None, self.out_channels])
if getattr(self, "activation", None) is not None:
with tf.name_scope(self.activation.name):
self.activation.build(None)
class TFEfficientFormerPooling(tf.keras.layers.Layer):
def __init__(self, pool_size: int, **kwargs):
@ -267,6 +311,8 @@ class TFEfficientFormerDenseMlp(tf.keras.layers.Layer):
self.linear_out = tf.keras.layers.Dense(
units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_out"
)
self.hidden_features = hidden_features
self.in_features = in_features
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.linear_in(inputs=hidden_states)
@ -277,6 +323,17 @@ class TFEfficientFormerDenseMlp(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "linear_in", None) is not None:
with tf.name_scope(self.linear_in.name):
self.linear_in.build([None, None, self.in_features])
if getattr(self, "linear_out", None) is not None:
with tf.name_scope(self.linear_out.name):
self.linear_out.build([None, None, self.hidden_features])
class TFEfficientFormerConvMlp(tf.keras.layers.Layer):
def __init__(
@ -318,6 +375,9 @@ class TFEfficientFormerConvMlp(tf.keras.layers.Layer):
self.batchnorm_after = tf.keras.layers.BatchNormalization(
axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after"
)
self.hidden_features = hidden_features
self.in_features = in_features
self.out_features = out_features
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.convolution1(hidden_state)
@ -329,6 +389,23 @@ class TFEfficientFormerConvMlp(tf.keras.layers.Layer):
hidden_state = self.dropout(hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution1", None) is not None:
with tf.name_scope(self.convolution1.name):
self.convolution1.build([None, None, None, self.in_features])
if getattr(self, "convolution2", None) is not None:
with tf.name_scope(self.convolution2.name):
self.convolution2.build([None, None, None, self.hidden_features])
if getattr(self, "batchnorm_before", None) is not None:
with tf.name_scope(self.batchnorm_before.name):
self.batchnorm_before.build([None, None, None, self.hidden_features])
if getattr(self, "batchnorm_after", None) is not None:
with tf.name_scope(self.batchnorm_after.name):
self.batchnorm_after.build([None, None, None, self.out_features])
# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer
class TFEfficientFormerDropPath(tf.keras.layers.Layer):
@ -390,7 +467,7 @@ class TFEfficientFormerMeta3D(tf.keras.layers.Layer):
)
self.config = config
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.layer_scale_1 = None
self.layer_scale_2 = None
@ -407,7 +484,25 @@ class TFEfficientFormerMeta3D(tf.keras.layers.Layer):
trainable=True,
name="layer_scale_2",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "token_mixer", None) is not None:
with tf.name_scope(self.token_mixer.name):
self.token_mixer.build(None)
if getattr(self, "layernorm1", None) is not None:
with tf.name_scope(self.layernorm1.name):
self.layernorm1.build([None, None, self.dim])
if getattr(self, "layernorm2", None) is not None:
with tf.name_scope(self.layernorm2.name):
self.layernorm2.build([None, None, self.dim])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
def call(
self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
@ -476,6 +571,15 @@ class TFEfficientFormerMeta3DLayers(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "blocks", None) is not None:
for layer in self.blocks:
with tf.name_scope(layer.name):
layer.build(None)
class TFEfficientFormerMeta4D(tf.keras.layers.Layer):
def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
@ -495,7 +599,7 @@ class TFEfficientFormerMeta4D(tf.keras.layers.Layer):
)
self.config = config
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.layer_scale_1 = None
self.layer_scale_2 = None
@ -512,7 +616,19 @@ class TFEfficientFormerMeta4D(tf.keras.layers.Layer):
trainable=True,
name="layer_scale_2",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "token_mixer", None) is not None:
with tf.name_scope(self.token_mixer.name):
self.token_mixer.build(None)
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
outputs = self.token_mixer(hidden_states)
@ -560,6 +676,15 @@ class TFEfficientFormerMeta4DLayers(tf.keras.layers.Layer):
hidden_states = layer_module(hidden_states=hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "blocks", None) is not None:
for layer in self.blocks:
with tf.name_scope(layer.name):
layer.build(None)
class TFEfficientFormerIntermediateStage(tf.keras.layers.Layer):
def __init__(self, config: EfficientFormerConfig, index: int, **kwargs):
@ -570,6 +695,14 @@ class TFEfficientFormerIntermediateStage(tf.keras.layers.Layer):
hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "meta4D_layers", None) is not None:
with tf.name_scope(self.meta4D_layers.name):
self.meta4D_layers.build(None)
class TFEfficientFormerLastStage(tf.keras.layers.Layer):
def __init__(self, config: EfficientFormerConfig, **kwargs):
@ -589,6 +722,20 @@ class TFEfficientFormerLastStage(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "meta4D_layers", None) is not None:
with tf.name_scope(self.meta4D_layers.name):
self.meta4D_layers.build(None)
if getattr(self, "flat", None) is not None:
with tf.name_scope(self.flat.name):
self.flat.build(None)
if getattr(self, "meta3D_layers", None) is not None:
with tf.name_scope(self.meta3D_layers.name):
self.meta3D_layers.build(None)
class TFEfficientFormerEncoder(tf.keras.layers.Layer):
def __init__(self, config: EfficientFormerConfig, **kwargs):
@ -658,6 +805,17 @@ class TFEfficientFormerEncoder(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "last_stage", None) is not None:
with tf.name_scope(self.last_stage.name):
self.last_stage.build(None)
for layer in self.intermediate_stages:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFEfficientFormerMainLayer(tf.keras.layers.Layer):
@ -728,6 +886,20 @@ class TFEfficientFormerMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "patch_embed", None) is not None:
with tf.name_scope(self.patch_embed.name):
self.patch_embed.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_sizes[-1]])
class TFEfficientFormerPreTrainedModel(TFPreTrainedModel):
"""
@ -804,6 +976,14 @@ class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "efficientformer", None) is not None:
with tf.name_scope(self.efficientformer.name):
self.efficientformer.build(None)
@add_start_docstrings(
"""
@ -825,6 +1005,7 @@ class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel,
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="classifier")
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
@ -873,6 +1054,18 @@ class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel,
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "efficientformer", None) is not None:
with tf.name_scope(self.efficientformer.name):
self.efficientformer.build(None)
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_sizes[-1]])
@dataclass
class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
@ -984,3 +1177,19 @@ class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTra
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "efficientformer", None) is not None:
with tf.name_scope(self.efficientformer.name):
self.efficientformer.build(None)
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_sizes[-1]])
if getattr(self, "distillation_classifier", None) is not None:
if hasattr(self.distillation_classifier, "name"):
with tf.name_scope(self.distillation_classifier.name):
self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]])

View File

@ -103,6 +103,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -192,6 +193,20 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra
class TFElectraSelfOutput(tf.keras.layers.Layer):
@ -203,6 +218,7 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -211,6 +227,17 @@ class TFElectraSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra
class TFElectraAttention(tf.keras.layers.Layer):
@ -252,6 +279,17 @@ class TFElectraAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra
class TFElectraIntermediate(tf.keras.layers.Layer):
@ -266,6 +304,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -273,6 +312,14 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra
class TFElectraOutput(tf.keras.layers.Layer):
@ -284,6 +331,7 @@ class TFElectraOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -292,6 +340,17 @@ class TFElectraOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra
class TFElectraLayer(tf.keras.layers.Layer):
@ -379,6 +438,23 @@ class TFElectraLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra
class TFElectraEncoder(tf.keras.layers.Layer):
@ -449,6 +525,15 @@ class TFElectraEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra
class TFElectraPooler(tf.keras.layers.Layer):
@ -461,6 +546,7 @@ class TFElectraPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -470,6 +556,14 @@ class TFElectraPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra
class TFElectraEmbeddings(tf.keras.layers.Layer):
@ -485,7 +579,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -507,7 +601,12 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(
@ -566,6 +665,17 @@ class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer):
return logits
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "dense_prediction", None) is not None:
with tf.name_scope(self.dense_prediction.name):
self.dense_prediction.build([None, None, self.config.hidden_size])
class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -573,6 +683,7 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")
self.config = config
def call(self, generator_hidden_states, training=False):
hidden_states = self.dense(generator_hidden_states)
@ -581,6 +692,17 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFElectraPreTrainedModel(TFPreTrainedModel):
"""
@ -781,6 +903,20 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "embeddings_project", None) is not None:
with tf.name_scope(self.embeddings_project.name):
self.embeddings_project.build([None, None, self.config.embedding_size])
@dataclass
class TFElectraForPreTrainingOutput(ModelOutput):
@ -977,6 +1113,14 @@ class TFElectraModel(TFElectraPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
@add_start_docstrings(
"""
@ -1049,6 +1193,17 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "discriminator_predictions", None) is not None:
with tf.name_scope(self.discriminator_predictions.name):
self.discriminator_predictions.build(None)
class TFElectraMaskedLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -1177,6 +1332,20 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
attentions=generator_hidden_states.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "generator_predictions", None) is not None:
with tf.name_scope(self.generator_predictions.name):
self.generator_predictions.build(None)
if getattr(self, "generator_lm_head", None) is not None:
with tf.name_scope(self.generator_lm_head.name):
self.generator_lm_head.build(None)
class TFElectraClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -1196,6 +1365,7 @@ class TFElectraClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, inputs, **kwargs):
x = inputs[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1207,6 +1377,17 @@ class TFElectraClassificationHead(tf.keras.layers.Layer):
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1278,6 +1459,17 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1297,6 +1489,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1370,6 +1563,20 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1391,6 +1598,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1448,6 +1656,17 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1465,6 +1684,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1541,3 +1761,14 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "electra", None) is not None:
with tf.name_scope(self.electra.name):
self.electra.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -650,3 +650,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "enc_to_dec_proj", None) is not None:
with tf.name_scope(self.enc_to_dec_proj.name):
self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)

View File

@ -149,10 +149,13 @@ class TFEsmContactPredictionHead(Layer):
self.in_features = in_features
self.regression = Dense(1, use_bias=bias, activation="sigmoid", name="regression")
def build(self, input_shape):
super().build(input_shape)
with tf.name_scope("regression"):
self.regression.build((None, self.in_features))
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "regression", None) is not None:
with tf.name_scope(self.regression.name):
self.regression.build((None, self.in_features))
def call(self, tokens, attentions):
# remove eos token attentions
@ -268,6 +271,20 @@ class TFEsmEmbeddings(Layer):
)
return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "word_embeddings", None) is not None:
with tf.name_scope(self.word_embeddings.name):
self.word_embeddings.build(None)
if getattr(self, "position_embeddings", None) is not None:
with tf.name_scope(self.position_embeddings.name):
self.position_embeddings.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
class TFEsmSelfAttention(Layer):
def __init__(self, config, position_embedding_type=None, name=None):
@ -306,6 +323,7 @@ class TFEsmSelfAttention(Layer):
self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
@ -415,6 +433,23 @@ class TFEsmSelfAttention(Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "rotary_embeddings", None) is not None:
with tf.name_scope(self.rotary_embeddings.name):
self.rotary_embeddings.build(None)
class TFEsmSelfOutput(Layer):
def __init__(self, config, name=None):
@ -423,6 +458,7 @@ class TFEsmSelfOutput(Layer):
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -430,6 +466,14 @@ class TFEsmSelfOutput(Layer):
hidden_states += input_tensor
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFEsmAttention(Layer):
def __init__(self, config, name=None):
@ -438,6 +482,7 @@ class TFEsmAttention(Layer):
self.output_layer = TFEsmSelfOutput(config, name="output")
self.pruned_heads = set()
self.LayerNorm = LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def prune_heads(self, heads):
raise NotImplementedError
@ -468,6 +513,20 @@ class TFEsmAttention(Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "output_layer", None) is not None:
with tf.name_scope(self.output_layer.name):
self.output_layer.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFEsmIntermediate(tf.keras.layers.Layer):
def __init__(self, config: EsmConfig, **kwargs):
@ -478,12 +537,21 @@ class TFEsmIntermediate(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = tf.nn.gelu(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFEsmOutput(Layer):
def __init__(self, config, name=None):
@ -492,6 +560,7 @@ class TFEsmOutput(Layer):
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -499,6 +568,14 @@ class TFEsmOutput(Layer):
hidden_states += input_tensor
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFEsmLayer(Layer):
def __init__(self, config, name=None):
@ -515,6 +592,7 @@ class TFEsmLayer(Layer):
self.intermediate = TFEsmIntermediate(config, name="intermediate")
self.output_layer = TFEsmOutput(config, name="output")
self.LayerNorm = LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(
self,
@ -586,6 +664,23 @@ class TFEsmLayer(Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "output_layer", None) is not None:
with tf.name_scope(self.output_layer.name):
self.output_layer.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFEsmEncoder(Layer):
def __init__(self, config, name=None):
@ -665,6 +760,18 @@ class TFEsmEncoder(Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "emb_layer_norm_after", None) is not None:
with tf.name_scope(self.emb_layer_norm_after.name):
self.emb_layer_norm_after.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm
class TFEsmPooler(tf.keras.layers.Layer):
@ -677,6 +784,7 @@ class TFEsmPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -686,6 +794,14 @@ class TFEsmPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFEsmPreTrainedModel(TFPreTrainedModel):
"""
@ -787,10 +903,22 @@ class TFEsmMainLayer(Layer):
in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
)
def build(self, input_shape):
super().build(input_shape)
with tf.name_scope("contact_head"):
self.contact_head.build(input_shape)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "contact_head", None) is not None:
with tf.name_scope(self.contact_head.name):
self.contact_head.build(None)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
@ -1041,6 +1169,14 @@ class TFEsmModel(TFEsmPreTrainedModel):
def predict_contacts(self, tokens, attention_mask):
return self.esm.predict_contacts(tokens, attention_mask)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "esm", None) is not None:
with tf.name_scope(self.esm.name):
self.esm.build(None)
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1140,6 +1276,17 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
def predict_contacts(self, tokens, attention_mask):
return self.esm.predict_contacts(tokens, attention_mask)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "esm", None) is not None:
with tf.name_scope(self.esm.name):
self.esm.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
class TFEsmLMHead(Layer):
"""ESM Head for masked language modeling."""
@ -1162,11 +1309,22 @@ class TFEsmLMHead(Layer):
)
self.config = config
def build(self, input_shape):
super().build(input_shape)
def build(self, input_shape=None):
# Separate bias to match the PT model and allow weight cross-loading to work
# Put it in the build so it gets the right name when adding it as a weight
if self.built:
return
self.built = True
self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
with tf.name_scope(self.decoder.name):
self.decoder.build([None, None, self.config.hidden_size])
def get_bias(self):
return {"bias": self.bias}
@ -1257,6 +1415,17 @@ class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificat
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "esm", None) is not None:
with tf.name_scope(self.esm.name):
self.esm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1276,6 +1445,7 @@ class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLos
self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
self.dropout = Dropout(config.hidden_dropout_prob)
self.classifier = Dense(config.num_labels, name="classifier")
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1333,6 +1503,17 @@ class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "esm", None) is not None:
with tf.name_scope(self.esm.name):
self.esm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
class TFEsmClassificationHead(Layer):
"""Head for sentence-level classification tasks."""
@ -1352,6 +1533,7 @@ class TFEsmClassificationHead(Layer):
activation="linear",
name="out_proj",
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1361,6 +1543,17 @@ class TFEsmClassificationHead(Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""

View File

@ -290,6 +290,14 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert
class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
@ -309,6 +317,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.pruned_heads = set()
self.dim = dim
def prune_heads(self, heads):
raise NotImplementedError
@ -383,6 +392,23 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_lin", None) is not None:
with tf.name_scope(self.q_lin.name):
self.q_lin.build([None, None, self.dim])
if getattr(self, "k_lin", None) is not None:
with tf.name_scope(self.k_lin.name):
self.k_lin.build([None, None, self.dim])
if getattr(self, "v_lin", None) is not None:
with tf.name_scope(self.v_lin.name):
self.v_lin.build([None, None, self.dim])
if getattr(self, "out_lin", None) is not None:
with tf.name_scope(self.out_lin.name):
self.out_lin.build([None, None, self.dim])
# Copied from transformers.models.xlm.modeling_tf_xlm.TFXLMTransformerFFN
class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
@ -393,6 +419,8 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.in_dim = in_dim
self.dim_hidden = dim_hidden
def call(self, input, training=False):
x = self.lin1(input)
@ -402,6 +430,17 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lin1", None) is not None:
with tf.name_scope(self.lin1.name):
self.lin1.build([None, None, self.in_dim])
if getattr(self, "lin2", None) is not None:
with tf.name_scope(self.lin2.name):
self.lin2.build([None, None, self.dim_hidden])
@keras_serializable
class TFFlaubertMainLayer(tf.keras.layers.Layer):
@ -454,7 +493,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name=f"layer_norm2_._{i}")
)
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
@ -470,7 +509,27 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
initializer=get_initializer(self.embed_init_std),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "layer_norm_emb", None) is not None:
with tf.name_scope(self.layer_norm_emb.name):
self.layer_norm_emb.build([None, None, self.dim])
for layer in self.attentions:
with tf.name_scope(layer.name):
layer.build(None)
for layer in self.layer_norm1:
with tf.name_scope(layer.name):
layer.build([None, None, self.dim])
for layer in self.ffns:
with tf.name_scope(layer.name):
layer.build(None)
for layer in self.layer_norm2:
with tf.name_scope(layer.name):
layer.build([None, None, self.dim])
def get_input_embeddings(self):
return self.embeddings
@ -841,6 +900,17 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "pred_layer", None) is not None:
with tf.name_scope(self.pred_layer.name):
self.pred_layer.build(None)
@add_start_docstrings(
"""
@ -920,6 +990,17 @@ class TFFlaubertForSequenceClassification(TFFlaubertPreTrainedModel, TFSequenceC
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
@add_start_docstrings(
"""
@ -936,6 +1017,7 @@ class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestion
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1012,6 +1094,17 @@ class TFFlaubertForQuestionAnsweringSimple(TFFlaubertPreTrainedModel, TFQuestion
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1031,6 +1124,7 @@ class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassif
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1093,6 +1187,17 @@ class TFFlaubertForTokenClassification(TFFlaubertPreTrainedModel, TFTokenClassif
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1111,6 +1216,7 @@ class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLos
self.logits_proj = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
self.config = config
@property
def dummy_inputs(self):
@ -1214,3 +1320,17 @@ class TFFlaubertForMultipleChoice(TFFlaubertPreTrainedModel, TFMultipleChoiceLos
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "logits_proj", None) is not None:
with tf.name_scope(self.logits_proj.name):
self.logits_proj.build([None, None, self.config.num_labels])

View File

@ -90,7 +90,7 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout)
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -98,7 +98,12 @@ class TFFunnelEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(initializer_range=self.initializer_std),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.d_model])
def call(self, input_ids=None, inputs_embeds=None, training=False):
"""
@ -407,7 +412,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.scale = 1.0 / (d_head**0.5)
def build(self, input_shape):
def build(self, input_shape=None):
n_head, d_head, d_model = self.n_head, self.d_head, self.d_model
initializer = get_initializer(self.initializer_range)
@ -426,7 +431,25 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
self.seg_embed = self.add_weight(
shape=(2, n_head, d_head), initializer=initializer, trainable=True, name="seg_embed"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "q_head", None) is not None:
with tf.name_scope(self.q_head.name):
self.q_head.build([None, None, d_model])
if getattr(self, "k_head", None) is not None:
with tf.name_scope(self.k_head.name):
self.k_head.build([None, None, d_model])
if getattr(self, "v_head", None) is not None:
with tf.name_scope(self.v_head.name):
self.v_head.build([None, None, d_model])
if getattr(self, "post_proj", None) is not None:
with tf.name_scope(self.post_proj.name):
self.post_proj.build([None, None, n_head * d_head])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, d_model])
def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
"""Relative attention score for the positional encodings"""
@ -557,6 +580,7 @@ class TFFunnelPositionwiseFFN(tf.keras.layers.Layer):
self.linear_2 = tf.keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="linear_2")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.config = config
def call(self, hidden, training=False):
h = self.linear_1(hidden)
@ -566,6 +590,20 @@ class TFFunnelPositionwiseFFN(tf.keras.layers.Layer):
h = self.dropout(h, training=training)
return self.layer_norm(hidden + h)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "linear_1", None) is not None:
with tf.name_scope(self.linear_1.name):
self.linear_1.build([None, None, self.config.d_model])
if getattr(self, "linear_2", None) is not None:
with tf.name_scope(self.linear_2.name):
self.linear_2.build([None, None, self.config.d_inner])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
class TFFunnelLayer(tf.keras.layers.Layer):
def __init__(self, config, block_index, **kwargs):
@ -580,6 +618,17 @@ class TFFunnelLayer(tf.keras.layers.Layer):
output = self.ffn(attn[0], training=training)
return (output, attn[1]) if output_attentions else (output,)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "ffn", None) is not None:
with tf.name_scope(self.ffn.name):
self.ffn.build(None)
class TFFunnelEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -650,6 +699,15 @@ class TFFunnelEncoder(tf.keras.layers.Layer):
return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
for block in self.blocks:
for layer in block:
with tf.name_scope(layer.name):
layer.build(None)
def upsample(x, stride, target_len, separate_cls=True, truncate_seq=False):
"""
@ -725,6 +783,15 @@ class TFFunnelDecoder(tf.keras.layers.Layer):
return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFFunnelBaseLayer(tf.keras.layers.Layer):
@ -795,6 +862,17 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
return encoder_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
@keras_serializable
class TFFunnelMainLayer(tf.keras.layers.Layer):
@ -895,6 +973,20 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
class TFFunnelDiscriminatorPredictions(tf.keras.layers.Layer):
"""Prediction module for the discriminator, made up of two dense layers."""
@ -905,6 +997,7 @@ class TFFunnelDiscriminatorPredictions(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(config.d_model, kernel_initializer=initializer, name="dense")
self.activation_function = get_tf_activation(config.hidden_act)
self.dense_prediction = tf.keras.layers.Dense(1, kernel_initializer=initializer, name="dense_prediction")
self.config = config
def call(self, discriminator_hidden_states):
hidden_states = self.dense(discriminator_hidden_states)
@ -912,6 +1005,17 @@ class TFFunnelDiscriminatorPredictions(tf.keras.layers.Layer):
logits = tf.squeeze(self.dense_prediction(hidden_states))
return logits
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.d_model])
if getattr(self, "dense_prediction", None) is not None:
with tf.name_scope(self.dense_prediction.name):
self.dense_prediction.build([None, None, self.config.d_model])
class TFFunnelMaskedLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -958,6 +1062,7 @@ class TFFunnelClassificationHead(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.linear_out = tf.keras.layers.Dense(n_labels, kernel_initializer=initializer, name="linear_out")
self.config = config
def call(self, hidden, training=False):
hidden = self.linear_hidden(hidden)
@ -965,6 +1070,17 @@ class TFFunnelClassificationHead(tf.keras.layers.Layer):
hidden = self.dropout(hidden, training=training)
return self.linear_out(hidden)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "linear_hidden", None) is not None:
with tf.name_scope(self.linear_hidden.name):
self.linear_hidden.build([None, None, self.config.d_model])
if getattr(self, "linear_out", None) is not None:
with tf.name_scope(self.linear_out.name):
self.linear_out.build([None, None, self.config.d_model])
class TFFunnelPreTrainedModel(TFPreTrainedModel):
"""
@ -1147,6 +1263,14 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
attentions=output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
@add_start_docstrings(
"The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
@ -1195,6 +1319,14 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
attentions=output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
@add_start_docstrings(
"""
@ -1268,6 +1400,17 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "discriminator_predictions", None) is not None:
with tf.name_scope(self.discriminator_predictions.name):
self.discriminator_predictions.build(None)
@add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1340,6 +1483,17 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
# different dimensions
return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
@add_start_docstrings(
"""
@ -1415,6 +1569,17 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1510,6 +1675,17 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1528,6 +1704,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1587,6 +1764,17 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1604,6 +1792,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1679,3 +1868,14 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
hidden_states=output.hidden_states,
attentions=output.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "funnel", None) is not None:
with tf.name_scope(self.funnel.name):
self.funnel.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -91,6 +91,7 @@ class TFAttention(tf.keras.layers.Layer):
self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.pruned_heads = set()
self.embed_dim = n_state
def prune_heads(self, heads):
pass
@ -202,6 +203,24 @@ class TFAttention(tf.keras.layers.Layer):
outputs = [a, present] + attn_outputs[1:]
return outputs # a, present, (attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.is_cross_attention:
c_attn_shape = 2 * self.embed_dim
else:
c_attn_shape = 3 * self.embed_dim
if getattr(self, "c_proj", None) is not None:
with tf.name_scope(self.c_proj.name):
self.c_proj.build([None, None, self.embed_dim])
if getattr(self, "c_attn", None) is not None:
with tf.name_scope(self.c_attn.name):
self.c_attn.build([None, None, c_attn_shape])
if getattr(self, "q_attn", None) is not None:
with tf.name_scope(self.q_attn.name):
self.q_attn.build([None, None, self.embed_dim])
class TFMLP(tf.keras.layers.Layer):
def __init__(self, n_state, config, **kwargs):
@ -211,6 +230,8 @@ class TFMLP(tf.keras.layers.Layer):
self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
self.act = get_tf_activation(config.activation_function)
self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.intermediate_size = n_state
self.embed_dim = nx
def call(self, x, training=False):
h = self.act(self.c_fc(x))
@ -218,6 +239,17 @@ class TFMLP(tf.keras.layers.Layer):
h2 = self.dropout(h2, training=training)
return h2
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "c_fc", None) is not None:
with tf.name_scope(self.c_fc.name):
self.c_fc.build([None, None, self.intermediate_size])
if getattr(self, "c_proj", None) is not None:
with tf.name_scope(self.c_proj.name):
self.c_proj.build([None, None, self.embed_dim])
class TFBlock(tf.keras.layers.Layer):
def __init__(self, config, scale=False, **kwargs):
@ -235,6 +267,7 @@ class TFBlock(tf.keras.layers.Layer):
)
self.mlp = TFMLP(inner_dim, config, name="mlp")
self.hidden_size = config.hidden_size
def call(
self,
@ -296,6 +329,29 @@ class TFBlock(tf.keras.layers.Layer):
outputs = [x] + outputs
return outputs # x, present, (attentions, cross_attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "ln_1", None) is not None:
with tf.name_scope(self.ln_1.name):
self.ln_1.build([None, None, self.hidden_size])
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "ln_2", None) is not None:
with tf.name_scope(self.ln_2.name):
self.ln_2.build([None, None, self.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
if getattr(self, "ln_cross_attn", None) is not None:
with tf.name_scope(self.ln_cross_attn.name):
self.ln_cross_attn.build([None, None, self.hidden_size])
@keras_serializable
class TFGPT2MainLayer(tf.keras.layers.Layer):
@ -330,6 +386,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
self.embed_dim = config.hidden_size
def get_input_embeddings(self):
return self.wte
@ -509,6 +566,24 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wte", None) is not None:
with tf.name_scope(self.wte.name):
self.wte.build(None)
if getattr(self, "wpe", None) is not None:
with tf.name_scope(self.wpe.name):
self.wpe.build(None)
if getattr(self, "ln_f", None) is not None:
with tf.name_scope(self.ln_f.name):
self.ln_f.build([None, None, self.embed_dim])
if getattr(self, "h", None) is not None:
for layer in self.h:
with tf.name_scope(layer.name):
layer.build(None)
class TFGPT2PreTrainedModel(TFPreTrainedModel):
"""
@ -751,6 +826,14 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -883,6 +966,14 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
cross_attentions=transformer_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -1012,6 +1103,17 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
"mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="mc_token_ids"),
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "multiple_choice_head", None) is not None:
with tf.name_scope(self.multiple_choice_head.name):
self.multiple_choice_head.build(None)
@add_start_docstrings(
"""
@ -1039,6 +1141,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
use_bias=False,
)
self.transformer = TFGPT2MainLayer(config, name="transformer")
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@ -1127,3 +1230,14 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "score", None) is not None:
with tf.name_scope(self.score.name):
self.score.build([None, None, self.config.n_embd])
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)

View File

@ -267,6 +267,23 @@ class TFGPTJAttention(tf.keras.layers.Layer):
return outputs # a, present, (attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFGPTJMLP(tf.keras.layers.Layer):
def __init__(self, intermediate_size: int, config: GPTJConfig, **kwargs):
@ -282,6 +299,8 @@ class TFGPTJMLP(tf.keras.layers.Layer):
self.act = get_tf_activation(config.activation_function)
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
self.embed_dim = config.n_embd
self.intermediate_size = intermediate_size
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.fc_in(hidden_states)
@ -290,6 +309,17 @@ class TFGPTJMLP(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "fc_in", None) is not None:
with tf.name_scope(self.fc_in.name):
self.fc_in.build([None, None, self.embed_dim])
if getattr(self, "fc_out", None) is not None:
with tf.name_scope(self.fc_out.name):
self.fc_out.build([None, None, self.intermediate_size])
class TFGPTJBlock(tf.keras.layers.Layer):
def __init__(self, config: GPTJConfig, **kwargs):
@ -298,6 +328,7 @@ class TFGPTJBlock(tf.keras.layers.Layer):
self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
self.attn = TFGPTJAttention(config, name="attn")
self.mlp = TFGPTJMLP(inner_dim, config, name="mlp")
self.config = config
def call(
self,
@ -332,6 +363,20 @@ class TFGPTJBlock(tf.keras.layers.Layer):
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "ln_1", None) is not None:
with tf.name_scope(self.ln_1.name):
self.ln_1.build([None, None, self.config.n_embd])
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
@keras_serializable
class TFGPTJMainLayer(tf.keras.layers.Layer):
@ -357,6 +402,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFGPTJBlock(config, name=f"h_._{i}") for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
self.embed_dim = config.n_embd
def get_input_embeddings(self):
return self.wte
@ -500,6 +546,21 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
attentions=all_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wte", None) is not None:
with tf.name_scope(self.wte.name):
self.wte.build(None)
if getattr(self, "ln_f", None) is not None:
with tf.name_scope(self.ln_f.name):
self.ln_f.build([None, None, self.embed_dim])
if getattr(self, "h", None) is not None:
for layer in self.h:
with tf.name_scope(layer.name):
layer.build(None)
class TFGPTJPreTrainedModel(TFPreTrainedModel):
"""
@ -672,6 +733,14 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -686,6 +755,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
)
self.config = config
def get_output_embeddings(self):
return self.lm_head
@ -784,6 +854,17 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.config.n_embd])
@add_start_docstrings(
"""
@ -813,6 +894,7 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
kernel_initializer=get_initializer(config.initializer_range),
name="score",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -906,6 +988,17 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "score", None) is not None:
with tf.name_scope(self.score.name):
self.score.build([None, None, self.config.n_embd])
@add_start_docstrings(
"""
@ -924,6 +1017,7 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
self.qa_outputs = tf.keras.layers.Dense(
self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -996,3 +1090,14 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -271,6 +271,7 @@ class TFGroupViTCrossAttentionLayer(tf.keras.layers.Layer):
self.norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm2")
self.mlp = TFGroupViTMLP(config, name="mlp")
self.norm_post = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post")
self.config = config
def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False) -> tf.Tensor:
x = query
@ -279,6 +280,23 @@ class TFGroupViTCrossAttentionLayer(tf.keras.layers.Layer):
x = self.norm_post(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "norm2", None) is not None:
with tf.name_scope(self.norm2.name):
self.norm2.build([None, None, self.config.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "norm_post", None) is not None:
with tf.name_scope(self.norm_post.name):
self.norm_post.build([None, None, self.config.hidden_size])
class TFGroupViTAssignAttention(tf.keras.layers.Layer):
def __init__(self, config: GroupViTVisionConfig, **kwargs):
@ -290,6 +308,7 @@ class TFGroupViTAssignAttention(tf.keras.layers.Layer):
self.v_proj = tf.keras.layers.Dense(config.hidden_size, name="v_proj")
self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj")
self.assign_eps = config.assign_eps
self.config = config
def get_attn(self, attn: tf.Tensor, gumbel: bool = True, hard: bool = True, training: bool = False) -> tf.Tensor:
if gumbel and training:
@ -327,6 +346,23 @@ class TFGroupViTAssignAttention(tf.keras.layers.Layer):
return out, soft_attn
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.config.hidden_size])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.config.hidden_size])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.config.hidden_size])
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, self.config.hidden_size])
class TFGroupViTTokenAssign(tf.keras.layers.Layer):
def __init__(self, config: GroupViTVisionConfig, num_group_token: int, num_output_group: int, **kwargs):
@ -353,6 +389,7 @@ class TFGroupViTTokenAssign(tf.keras.layers.Layer):
self.mlp_channels = TFGroupViTMLP(
config, config.hidden_size, channels_dim, config.hidden_size, name="mlp_channels"
)
self.config = config
def project_group_token(self, group_tokens: tf.Tensor) -> tf.Tensor:
"""
@ -386,6 +423,35 @@ class TFGroupViTTokenAssign(tf.keras.layers.Layer):
return new_image_tokens, attention
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "norm_tokens", None) is not None:
with tf.name_scope(self.norm_tokens.name):
self.norm_tokens.build([None, None, self.config.hidden_size])
if getattr(self, "mlp_inter", None) is not None:
with tf.name_scope(self.mlp_inter.name):
self.mlp_inter.build(None)
if getattr(self, "norm_post_tokens", None) is not None:
with tf.name_scope(self.norm_post_tokens.name):
self.norm_post_tokens.build([None, None, self.config.hidden_size])
if getattr(self, "norm_x", None) is not None:
with tf.name_scope(self.norm_x.name):
self.norm_x.build([None, None, self.config.hidden_size])
if getattr(self, "pre_assign_attn", None) is not None:
with tf.name_scope(self.pre_assign_attn.name):
self.pre_assign_attn.build(None)
if getattr(self, "assign", None) is not None:
with tf.name_scope(self.assign.name):
self.assign.build(None)
if getattr(self, "norm_new_x", None) is not None:
with tf.name_scope(self.norm_new_x.name):
self.norm_new_x.build([None, None, self.config.hidden_size])
if getattr(self, "mlp_channels", None) is not None:
with tf.name_scope(self.mlp_channels.name):
self.mlp_channels.build(None)
# Adapted from transformers.models.vit.modeling_tf_vit.TFViTPatchEmbeddings with ViT->GroupViT
class TFGroupViTPatchEmbeddings(tf.keras.layers.Layer):
@ -457,6 +523,14 @@ class TFGroupViTPatchEmbeddings(tf.keras.layers.Layer):
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
# Adapted from transformers.vit.modeling_tf_vit.TFViTEmbeddings
class TFGroupViTVisionEmbeddings(tf.keras.layers.Layer):
@ -473,7 +547,7 @@ class TFGroupViTVisionEmbeddings(tf.keras.layers.Layer):
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.config = config
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = self.add_weight(
shape=(1, num_patches, self.config.hidden_size),
@ -482,7 +556,18 @@ class TFGroupViTVisionEmbeddings(tf.keras.layers.Layer):
name="position_embeddings",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
@ -626,7 +711,7 @@ class TFGroupViTStage(tf.keras.layers.Layer):
else:
self.group_projector = None
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
if self.num_group_token > 0:
self.group_token = self.add_weight(
shape=(1, self.num_group_token, self.config.hidden_size),
@ -636,7 +721,22 @@ class TFGroupViTStage(tf.keras.layers.Layer):
)
else:
self.group_token = None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "downsample", None) is not None:
with tf.name_scope(self.downsample.name):
self.downsample.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
if getattr(self, "group_projector", None) is not None:
with tf.name_scope(self.group_projector[0].name):
self.group_projector[0].build([None, None, self.config.hidden_size])
with tf.name_scope(self.group_projector[1].name):
self.group_projector[1].build(None)
@property
def with_group_token(self):
@ -720,6 +820,8 @@ class TFGroupViTMLP(tf.keras.layers.Layer):
output_size = output_size if output_size is not None else hidden_size
self.fc1 = tf.keras.layers.Dense(intermediate_size, name="fc1")
self.fc2 = tf.keras.layers.Dense(output_size, name="fc2")
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.fc1(hidden_states)
@ -727,6 +829,17 @@ class TFGroupViTMLP(tf.keras.layers.Layer):
hidden_states = self.fc2(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.hidden_size])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.intermediate_size])
class TFGroupViTMixerMLP(TFGroupViTMLP):
def call(self, x, training: bool = False):
@ -841,6 +954,23 @@ class TFGroupViTAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPEncoderLayer with CLIP->GroupViT
class TFGroupViTEncoderLayer(tf.keras.layers.Layer):
@ -894,6 +1024,23 @@ class TFGroupViTEncoderLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, self.embed_dim])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, self.embed_dim])
# Adapted from transformers.models.clip.modeling_tf_clip.TFGroupViTTextEncoder
class TFGroupViTTextEncoder(tf.keras.layers.Layer):
@ -939,6 +1086,15 @@ class TFGroupViTTextEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFGroupViTVisionEncoder(tf.keras.layers.Layer):
def __init__(self, config: GroupViTVisionConfig, **kwargs) -> None:
@ -990,6 +1146,15 @@ class TFGroupViTVisionEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "stages", None) is not None:
for layer in self.stages:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder
class TFGroupViTTextTransformer(tf.keras.layers.Layer):
@ -1004,6 +1169,7 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
self.embed_dim = config.hidden_size
def call(
self,
@ -1094,6 +1260,20 @@ class TFGroupViTTextTransformer(tf.keras.layers.Layer):
return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length))
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPVisionTransformer
class TFGroupViTVisionTransformer(tf.keras.layers.Layer):
@ -1103,6 +1283,7 @@ class TFGroupViTVisionTransformer(tf.keras.layers.Layer):
self.embeddings = TFGroupViTVisionEmbeddings(config, name="embeddings")
self.encoder = TFGroupViTVisionEncoder(config, name="encoder")
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.embed_dim = config.hidden_size
def call(
self,
@ -1137,6 +1318,20 @@ class TFGroupViTVisionTransformer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.embed_dim])
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextMainLayer with CLIP->GroupViT
@ -1186,6 +1381,14 @@ class TFGroupViTTextMainLayer(tf.keras.layers.Layer):
return text_model_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "text_model", None) is not None:
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPVisionMainLayer with CLIP->GroupViT
@ -1222,6 +1425,14 @@ class TFGroupViTVisionMainLayer(tf.keras.layers.Layer):
return vision_model_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
@keras_serializable
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPMainLayer
@ -1269,7 +1480,7 @@ class TFGroupViTMainLayer(tf.keras.layers.Layer):
tf.keras.layers.Dense(self.projection_dim, name="text_projection.3"),
]
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.logit_scale = self.add_weight(
shape=(1,),
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
@ -1277,7 +1488,29 @@ class TFGroupViTMainLayer(tf.keras.layers.Layer):
name="logit_scale",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "text_model", None) is not None:
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
if getattr(self, "vision_model", None) is not None:
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
if getattr(self, "visual_projection", None) is not None:
with tf.name_scope(self.visual_projection[0].name):
self.visual_projection[0].build([None, None, None, self.vision_embed_dim])
with tf.name_scope(self.visual_projection[1].name):
self.visual_projection[1].build((None, self.projection_intermediate_dim))
with tf.name_scope(self.visual_projection[3].name):
self.visual_projection[3].build([None, None, None, self.projection_intermediate_dim])
if getattr(self, "text_projection", None) is not None:
with tf.name_scope(self.text_projection[0].name):
self.text_projection[0].build([None, None, None, self.text_embed_dim])
with tf.name_scope(self.text_projection[1].name):
self.text_projection[1].build((None, self.projection_intermediate_dim))
with tf.name_scope(self.text_projection[3].name):
self.text_projection[3].build([None, None, None, self.projection_intermediate_dim])
@unpack_inputs
def get_text_features(
@ -1669,6 +1902,14 @@ class TFGroupViTTextModel(TFGroupViTPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "groupvit", None) is not None:
with tf.name_scope(self.groupvit.name):
self.groupvit.build(None)
class TFGroupViTVisionModel(TFGroupViTPreTrainedModel):
config_class = GroupViTVisionConfig
@ -1723,6 +1964,14 @@ class TFGroupViTVisionModel(TFGroupViTPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "groupvit", None) is not None:
with tf.name_scope(self.groupvit.name):
self.groupvit.build(None)
@add_start_docstrings(GROUPVIT_START_DOCSTRING)
class TFGroupViTModel(TFGroupViTPreTrainedModel):
@ -1879,3 +2128,11 @@ class TFGroupViTModel(TFGroupViTPreTrainedModel):
# TensorFlow cannot trace through nested dataclasses. Reference:
# https://github.com/huggingface/transformers/pull/16886
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "groupvit", None) is not None:
with tf.name_scope(self.groupvit.name):
self.groupvit.build(None)

View File

@ -416,11 +416,6 @@ class TFHubertWeightNormConv1D(tf.keras.layers.Conv1D):
def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# If a specific input shape is passed in, we need to modify it to account for padding
# Not necessary if those portions of the shape are None
if input_shape[-2] is not None:
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
@ -469,6 +464,14 @@ class TFHubertNoLayerNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
class TFHubertLayerNormConvLayer(tf.keras.layers.Layer):
@ -493,6 +496,17 @@ class TFHubertLayerNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
class TFHubertGroupNormConvLayer(tf.keras.layers.Layer):
@ -517,6 +531,17 @@ class TFHubertGroupNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
class TFHubertPositionalConvEmbedding(tf.keras.layers.Layer):
@ -531,6 +556,7 @@ class TFHubertPositionalConvEmbedding(tf.keras.layers.Layer):
)
self.padding = TFHubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = get_tf_activation(config.feat_extract_activation)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
@ -538,6 +564,14 @@ class TFHubertPositionalConvEmbedding(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.config.hidden_size])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2SamePadLayer with Wav2Vec2->Hubert
class TFHubertSamePadLayer(tf.keras.layers.Layer):
@ -577,6 +611,14 @@ class TFHubertFeatureEncoder(tf.keras.layers.Layer):
hidden_states = conv_layer(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
for conv_layer in self.conv_layers:
with tf.name_scope(conv_layer.name):
conv_layer.build(None)
class TFHubertFeatureExtractor(TFHubertFeatureEncoder):
def __init__(self, config, **kwargs):
@ -601,6 +643,7 @@ class TFHubertFeatureProjection(tf.keras.layers.Layer):
name="projection",
)
self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.layer_norm(hidden_states)
@ -608,6 +651,17 @@ class TFHubertFeatureProjection(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.conv_dim[-1]])
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, self.config.conv_dim[-1]])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFHubert
class TFHubertAttention(tf.keras.layers.Layer):
@ -762,6 +816,23 @@ class TFHubertAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2FeedForward with Wav2Vec2->Hubert
class TFHubertFeedForward(tf.keras.layers.Layer):
@ -785,6 +856,7 @@ class TFHubertFeedForward(tf.keras.layers.Layer):
name="output_dense",
)
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.intermediate_dense(hidden_states)
@ -795,6 +867,17 @@ class TFHubertFeedForward(tf.keras.layers.Layer):
hidden_states = self.output_dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "intermediate_dense", None) is not None:
with tf.name_scope(self.intermediate_dense.name):
self.intermediate_dense.build([None, None, self.config.hidden_size])
if getattr(self, "output_dense", None) is not None:
with tf.name_scope(self.output_dense.name):
self.output_dense.build([None, None, self.config.intermediate_size])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayer with Wav2Vec2->Hubert
class TFHubertEncoderLayer(tf.keras.layers.Layer):
@ -813,6 +896,7 @@ class TFHubertEncoderLayer(tf.keras.layers.Layer):
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
self.config = config
def call(
self,
@ -839,6 +923,23 @@ class TFHubertEncoderLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert
class TFHubertEncoderLayerStableLayerNorm(tf.keras.layers.Layer):
@ -857,6 +958,7 @@ class TFHubertEncoderLayerStableLayerNorm(tf.keras.layers.Layer):
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
self.config = config
def call(
self,
@ -881,6 +983,23 @@ class TFHubertEncoderLayerStableLayerNorm(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2Encoder with Wav2Vec2->Hubert
class TFHubertEncoder(tf.keras.layers.Layer):
@ -947,6 +1066,21 @@ class TFHubertEncoder(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
class TFHubertEncoderStableLayerNorm(tf.keras.layers.Layer):
@ -1015,6 +1149,21 @@ class TFHubertEncoderStableLayerNorm(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFHubertMainLayer(tf.keras.layers.Layer):
@ -1031,12 +1180,23 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
else:
self.encoder = TFHubertEncoder(config, name="encoder")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.masked_spec_embed = self.add_weight(
shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "feature_extractor", None) is not None:
with tf.name_scope(self.feature_extractor.name):
self.feature_extractor.build(None)
if getattr(self, "feature_projection", None) is not None:
with tf.name_scope(self.feature_projection.name):
self.feature_projection.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
"""
@ -1345,6 +1505,14 @@ class TFHubertModel(TFHubertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "hubert", None) is not None:
with tf.name_scope(self.hubert.name):
self.hubert.build(None)
@add_start_docstrings(
"""TFHubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
@ -1357,6 +1525,9 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
self.hubert = TFHubertMainLayer(config, name="hubert")
self.dropout = tf.keras.layers.Dropout(config.final_dropout)
self.lm_head = tf.keras.layers.Dense(config.vocab_size, name="lm_head")
self.output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
def freeze_feature_extractor(self):
"""
@ -1497,3 +1668,14 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "hubert", None) is not None:
with tf.name_scope(self.hubert.name):
self.hubert.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.output_hidden_size])

View File

@ -73,7 +73,7 @@ class TFLayoutLMEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -123,7 +123,12 @@ class TFLayoutLMEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def call(
self,
@ -216,6 +221,7 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -305,6 +311,20 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->LayoutLM
class TFLayoutLMSelfOutput(tf.keras.layers.Layer):
@ -316,6 +336,7 @@ class TFLayoutLMSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -324,6 +345,17 @@ class TFLayoutLMSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->LayoutLM
class TFLayoutLMAttention(tf.keras.layers.Layer):
@ -365,6 +397,17 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->LayoutLM
class TFLayoutLMIntermediate(tf.keras.layers.Layer):
@ -379,6 +422,7 @@ class TFLayoutLMIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -386,6 +430,14 @@ class TFLayoutLMIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->LayoutLM
class TFLayoutLMOutput(tf.keras.layers.Layer):
@ -397,6 +449,7 @@ class TFLayoutLMOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -405,6 +458,17 @@ class TFLayoutLMOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->LayoutLM
class TFLayoutLMLayer(tf.keras.layers.Layer):
@ -492,6 +556,23 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->LayoutLM
class TFLayoutLMEncoder(tf.keras.layers.Layer):
@ -562,6 +643,15 @@ class TFLayoutLMEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->LayoutLM
class TFLayoutLMPooler(tf.keras.layers.Layer):
@ -574,6 +664,7 @@ class TFLayoutLMPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -583,6 +674,14 @@ class TFLayoutLMPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->LayoutLM
class TFLayoutLMPredictionHeadTransform(tf.keras.layers.Layer):
@ -601,6 +700,7 @@ class TFLayoutLMPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -609,6 +709,17 @@ class TFLayoutLMPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->LayoutLM
class TFLayoutLMLMPredictionHead(tf.keras.layers.Layer):
@ -624,10 +735,15 @@ class TFLayoutLMLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -666,6 +782,14 @@ class TFLayoutLMMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@keras_serializable
class TFLayoutLMMainLayer(tf.keras.layers.Layer):
@ -796,6 +920,20 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFLayoutLMPreTrainedModel(TFPreTrainedModel):
"""
@ -986,6 +1124,14 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlm", None) is not None:
with tf.name_scope(self.layoutlm.name):
self.layoutlm.build(None)
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1107,6 +1253,17 @@ class TFLayoutLMForMaskedLM(TFLayoutLMPreTrainedModel, TFMaskedLanguageModelingL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlm", None) is not None:
with tf.name_scope(self.layoutlm.name):
self.layoutlm.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""
@ -1132,6 +1289,7 @@ class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceC
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1225,6 +1383,17 @@ class TFLayoutLMForSequenceClassification(TFLayoutLMPreTrainedModel, TFSequenceC
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlm", None) is not None:
with tf.name_scope(self.layoutlm.name):
self.layoutlm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1256,6 +1425,7 @@ class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassif
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1347,6 +1517,17 @@ class TFLayoutLMForTokenClassification(TFLayoutLMPreTrainedModel, TFTokenClassif
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlm", None) is not None:
with tf.name_scope(self.layoutlm.name):
self.layoutlm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1376,6 +1557,7 @@ class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnswer
kernel_initializer=get_initializer(config.initializer_range),
name="qa_outputs",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1485,3 +1667,14 @@ class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnswer
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlm", None) is not None:
with tf.name_scope(self.layoutlm.name):
self.layoutlm.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -87,6 +87,7 @@ class TFLayoutLMv3PatchEmbeddings(tf.keras.layers.Layer):
)
self.hidden_size = config.hidden_size
self.num_patches = (config.input_size**2) // (patch_sizes[0] * patch_sizes[1])
self.config = config
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
@ -97,6 +98,14 @@ class TFLayoutLMv3PatchEmbeddings(tf.keras.layers.Layer):
embeddings = tf.reshape(embeddings, (-1, self.num_patches, self.hidden_size))
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, None, self.config.num_channels])
class TFLayoutLMv3TextEmbeddings(tf.keras.layers.Layer):
"""
@ -151,6 +160,7 @@ class TFLayoutLMv3TextEmbeddings(tf.keras.layers.Layer):
name="w_position_embeddings",
)
self.max_2d_positions = config.max_2d_position_embeddings
self.config = config
def calculate_spatial_position_embeddings(self, bbox: tf.Tensor) -> tf.Tensor:
try:
@ -260,6 +270,35 @@ class TFLayoutLMv3TextEmbeddings(tf.keras.layers.Layer):
embeddings = self.dropout(embeddings, training=training)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "word_embeddings", None) is not None:
with tf.name_scope(self.word_embeddings.name):
self.word_embeddings.build(None)
if getattr(self, "token_type_embeddings", None) is not None:
with tf.name_scope(self.token_type_embeddings.name):
self.token_type_embeddings.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "position_embeddings", None) is not None:
with tf.name_scope(self.position_embeddings.name):
self.position_embeddings.build(None)
if getattr(self, "x_position_embeddings", None) is not None:
with tf.name_scope(self.x_position_embeddings.name):
self.x_position_embeddings.build(None)
if getattr(self, "y_position_embeddings", None) is not None:
with tf.name_scope(self.y_position_embeddings.name):
self.y_position_embeddings.build(None)
if getattr(self, "h_position_embeddings", None) is not None:
with tf.name_scope(self.h_position_embeddings.name):
self.h_position_embeddings.build(None)
if getattr(self, "w_position_embeddings", None) is not None:
with tf.name_scope(self.w_position_embeddings.name):
self.w_position_embeddings.build(None)
class TFLayoutLMv3SelfAttention(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMv3Config, **kwargs):
@ -294,6 +333,7 @@ class TFLayoutLMv3SelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
self.config = config
def transpose_for_scores(self, x: tf.Tensor):
shape = tf.shape(x)
@ -372,6 +412,20 @@ class TFLayoutLMv3SelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from models.roberta.modeling_tf_roberta.TFRobertaSelfOutput
class TFLayoutLMv3SelfOutput(tf.keras.layers.Layer):
@ -383,6 +437,7 @@ class TFLayoutLMv3SelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -391,6 +446,17 @@ class TFLayoutLMv3SelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFLayoutLMv3Attention(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMv3Config, **kwargs):
@ -421,6 +487,17 @@ class TFLayoutLMv3Attention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "self_output", None) is not None:
with tf.name_scope(self.self_output.name):
self.self_output.build(None)
# Copied from models.roberta.modeling_tf_bert.TFRobertaIntermediate
class TFLayoutLMv3Intermediate(tf.keras.layers.Layer):
@ -435,6 +512,7 @@ class TFLayoutLMv3Intermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -442,6 +520,14 @@ class TFLayoutLMv3Intermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from models.roberta.modeling_tf_bert.TFRobertaOutput
class TFLayoutLMv3Output(tf.keras.layers.Layer):
@ -453,6 +539,7 @@ class TFLayoutLMv3Output(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -461,6 +548,17 @@ class TFLayoutLMv3Output(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFLayoutLMv3Layer(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMv3Config, **kwargs):
@ -495,6 +593,20 @@ class TFLayoutLMv3Layer(tf.keras.layers.Layer):
outputs = (layer_output,) + outputs
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
class TFLayoutLMv3Encoder(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMv3Config, **kwargs):
@ -650,6 +762,24 @@ class TFLayoutLMv3Encoder(tf.keras.layers.Layer):
value for value in [hidden_states, all_hidden_states, all_self_attentions] if value is not None
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rel_pos_bias", None) is not None:
with tf.name_scope(self.rel_pos_bias.name):
self.rel_pos_bias.build([None, None, self.rel_pos_bins])
if getattr(self, "rel_pos_x_bias", None) is not None:
with tf.name_scope(self.rel_pos_x_bias.name):
self.rel_pos_x_bias.build([None, None, self.rel_2d_pos_bins])
if getattr(self, "rel_pos_y_bias", None) is not None:
with tf.name_scope(self.rel_pos_y_bias.name):
self.rel_pos_y_bias.build([None, None, self.rel_2d_pos_bins])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFLayoutLMv3MainLayer(tf.keras.layers.Layer):
@ -676,7 +806,7 @@ class TFLayoutLMv3MainLayer(tf.keras.layers.Layer):
self.encoder = TFLayoutLMv3Encoder(config, name="encoder")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
if self.config.visual_embed:
image_size = self.config.input_size // self.config.patch_size
self.cls_token = self.add_weight(
@ -694,7 +824,27 @@ class TFLayoutLMv3MainLayer(tf.keras.layers.Layer):
name="pos_embed",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "patch_embed", None) is not None:
with tf.name_scope(self.patch_embed.name):
self.patch_embed.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build([None, None, self.config.hidden_size])
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.word_embeddings
@ -1180,6 +1330,14 @@ class TFLayoutLMv3Model(TFLayoutLMv3PreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlmv3", None) is not None:
with tf.name_scope(self.layoutlmv3.name):
self.layoutlmv3.build(None)
class TFLayoutLMv3ClassificationHead(tf.keras.layers.Layer):
"""
@ -1206,6 +1364,7 @@ class TFLayoutLMv3ClassificationHead(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="out_proj",
)
self.config = config
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
outputs = self.dropout(inputs, training=training)
@ -1214,6 +1373,20 @@ class TFLayoutLMv3ClassificationHead(tf.keras.layers.Layer):
outputs = self.out_proj(outputs)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1317,6 +1490,17 @@ class TFLayoutLMv3ForSequenceClassification(TFLayoutLMv3PreTrainedModel, TFSeque
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlmv3", None) is not None:
with tf.name_scope(self.layoutlmv3.name):
self.layoutlmv3.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1345,6 +1529,7 @@ class TFLayoutLMv3ForTokenClassification(TFLayoutLMv3PreTrainedModel, TFTokenCla
)
else:
self.classifier = TFLayoutLMv3ClassificationHead(config, name="classifier")
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING)
@ -1440,6 +1625,20 @@ class TFLayoutLMv3ForTokenClassification(TFLayoutLMv3PreTrainedModel, TFTokenCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlmv3", None) is not None:
with tf.name_scope(self.layoutlmv3.name):
self.layoutlmv3.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1567,3 +1766,14 @@ class TFLayoutLMv3ForQuestionAnswering(TFLayoutLMv3PreTrainedModel, TFQuestionAn
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layoutlmv3", None) is not None:
with tf.name_scope(self.layoutlmv3.name):
self.layoutlmv3.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build(None)

View File

@ -37,7 +37,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
@ -200,7 +199,28 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
self.key_global.build((self.config.hidden_size,))
with tf.name_scope("value_global"):
self.value_global.build((self.config.hidden_size,))
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "query_global", None) is not None:
with tf.name_scope(self.query_global.name):
self.query_global.build([None, None, self.config.hidden_size])
if getattr(self, "key_global", None) is not None:
with tf.name_scope(self.key_global.name):
self.key_global.build([None, None, self.config.hidden_size])
if getattr(self, "value_global", None) is not None:
with tf.name_scope(self.value_global.name):
self.value_global.build([None, None, self.config.hidden_size])
def call(
self,
@ -983,6 +1003,7 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.longformer_self_attn = TFLEDEncoderSelfAttention(config, layer_id=layer_id, name="longformer_self_attn")
self.output_dense = tf.keras.layers.Dense(config.d_model, use_bias=True, name="output")
self.config = config
def call(self, inputs, training=False):
(
@ -1004,6 +1025,17 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer_self_attn", None) is not None:
with tf.name_scope(self.longformer_self_attn.name):
self.longformer_self_attn.build(None)
if getattr(self, "output_dense", None) is not None:
with tf.name_scope(self.output_dense.name):
self.output_dense.build([None, None, self.config.d_model])
class TFLEDDecoderAttention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
@ -1155,6 +1187,23 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFLEDEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: LEDConfig, layer_id: int, **kwargs):
@ -1168,6 +1217,7 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -1214,6 +1264,26 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
return (hidden_states,) + layer_outputs[1:]
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFLEDDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: LEDConfig, **kwargs):
@ -1242,6 +1312,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -1323,6 +1394,32 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFLEDPreTrainedModel(TFPreTrainedModel):
config_class = LEDConfig
@ -1662,6 +1759,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
)
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.embed_dim = config.d_model
def get_embed_tokens(self):
return self.embed_tokens
@ -1723,16 +1821,8 @@ class TFLEDEncoder(tf.keras.layers.Layer):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
@ -1884,6 +1974,21 @@ class TFLEDEncoder(tf.keras.layers.Layer):
inputs_embeds,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.embed_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFLEDDecoder(tf.keras.layers.Layer):
@ -2003,16 +2108,8 @@ class TFLEDDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
@ -2104,6 +2201,21 @@ class TFLEDDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFLEDMainLayer(tf.keras.layers.Layer):
@ -2210,6 +2322,22 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
encoder_global_attentions=encoder_outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare LED Model outputting raw hidden-states without any specific head on top.",
@ -2296,6 +2424,14 @@ class TFLEDModel(TFLEDPreTrainedModel):
encoder_global_attentions=enc_g_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "led", None) is not None:
with tf.name_scope(self.led.name):
self.led.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -2516,3 +2652,14 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
masked_loss = unmasked_loss * loss_mask
reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
return tf.reshape(reduced_masked_loss, (1,))
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "led", None) is not None:
with tf.name_scope(self.led.name):
self.led.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -434,10 +434,18 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -484,7 +492,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -506,7 +514,12 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
@ -582,6 +595,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -589,6 +603,14 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer
class TFLongformerOutput(tf.keras.layers.Layer):
@ -600,6 +622,7 @@ class TFLongformerOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -608,6 +631,17 @@ class TFLongformerOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer
class TFLongformerPooler(tf.keras.layers.Layer):
@ -620,6 +654,7 @@ class TFLongformerPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -629,6 +664,14 @@ class TFLongformerPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
class TFLongformerSelfOutput(tf.keras.layers.Layer):
@ -640,6 +683,7 @@ class TFLongformerSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -648,6 +692,17 @@ class TFLongformerSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFLongformerSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id, **kwargs):
@ -717,7 +772,28 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
self.key_global.build((self.config.hidden_size,))
with tf.name_scope("value_global"):
self.value_global.build((self.config.hidden_size,))
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
if getattr(self, "query_global", None) is not None:
with tf.name_scope(self.query_global.name):
self.query_global.build([None, None, self.config.hidden_size])
if getattr(self, "key_global", None) is not None:
with tf.name_scope(self.key_global.name):
self.key_global.build([None, None, self.config.hidden_size])
if getattr(self, "value_global", None) is not None:
with tf.name_scope(self.value_global.name):
self.value_global.build([None, None, self.config.hidden_size])
def call(
self,
@ -1524,6 +1600,17 @@ class TFLongformerAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFLongformerLayer(tf.keras.layers.Layer):
def __init__(self, config, layer_id=0, **kwargs):
@ -1554,6 +1641,20 @@ class TFLongformerLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "longformer_output", None) is not None:
with tf.name_scope(self.longformer_output.name):
self.longformer_output.build(None)
class TFLongformerEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -1632,6 +1733,15 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
global_attentions=all_global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFLongformerMainLayer(tf.keras.layers.Layer):
@ -1859,6 +1969,20 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
return attention_mask
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFLongformerPreTrainedModel(TFPreTrainedModel):
"""
@ -2044,6 +2168,14 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
@add_start_docstrings(
"""Longformer Model with a `language modeling` head on top.""",
@ -2128,6 +2260,17 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
global_attentions=outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
@add_start_docstrings(
"""
@ -2150,6 +2293,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
kernel_initializer=get_initializer(config.initializer_range),
name="qa_outputs",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -2258,6 +2402,17 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
global_attentions=outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
class TFLongformerClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -2274,6 +2429,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, hidden_states, training=False):
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
@ -2283,6 +2439,17 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
output = self.out_proj(hidden_states)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -2386,6 +2553,17 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
global_attentions=outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -2406,6 +2584,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@property
def input_signature(self):
@ -2500,6 +2679,17 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
global_attentions=outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -2522,6 +2712,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -2579,3 +2770,14 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "longformer", None) is not None:
with tf.name_scope(self.longformer.name):
self.longformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])

View File

@ -174,6 +174,9 @@ class TFLxmertVisualFeatureEncoder(tf.keras.layers.Layer):
self.box_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="box_layer_norm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.feat_dim = config.visual_feat_dim
self.pos_dim = config.visual_pos_dim
self.config = config
def call(self, visn_input, training=False):
feats, boxes = visn_input
@ -187,6 +190,23 @@ class TFLxmertVisualFeatureEncoder(tf.keras.layers.Layer):
output = self.dropout(output, training=training)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "visn_fc", None) is not None:
with tf.name_scope(self.visn_fc.name):
self.visn_fc.build([None, None, self.feat_dim])
if getattr(self, "visn_layer_norm", None) is not None:
with tf.name_scope(self.visn_layer_norm.name):
self.visn_layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "box_fc", None) is not None:
with tf.name_scope(self.box_fc.name):
self.box_fc.build([None, None, self.pos_dim])
if getattr(self, "box_layer_norm", None) is not None:
with tf.name_scope(self.box_layer_norm.name):
self.box_layer_norm.build([None, None, self.config.hidden_size])
class TFLxmertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
@ -201,7 +221,7 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -223,7 +243,12 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def call(self, input_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
"""
@ -284,6 +309,8 @@ class TFLxmertAttention(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.ctx_dim = config.hidden_size
self.config = config
def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -328,6 +355,20 @@ class TFLxmertAttention(tf.keras.layers.Layer):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.ctx_dim])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.ctx_dim])
class TFLxmertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -341,12 +382,21 @@ class TFLxmertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFLxmertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -359,6 +409,7 @@ class TFLxmertOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -366,6 +417,17 @@ class TFLxmertOutput(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFLxmertAttentionOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -377,6 +439,7 @@ class TFLxmertAttentionOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -384,6 +447,17 @@ class TFLxmertAttentionOutput(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFLxmertSelfAttentionLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -399,6 +473,17 @@ class TFLxmertSelfAttentionLayer(tf.keras.layers.Layer):
attention_output = self.attention_output(self_output[0], input_tensor)
return (attention_output, attention_probs) if output_attentions else (attention_output,)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "attention_output", None) is not None:
with tf.name_scope(self.attention_output.name):
self.attention_output.build(None)
class TFLxmertCrossAttentionLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -421,6 +506,17 @@ class TFLxmertCrossAttentionLayer(tf.keras.layers.Layer):
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "att", None) is not None:
with tf.name_scope(self.att.name):
self.att.build(None)
if getattr(self, "attention_output", None) is not None:
with tf.name_scope(self.attention_output.name):
self.attention_output.build(None)
class TFLxmertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -437,6 +533,20 @@ class TFLxmertLayer(tf.keras.layers.Layer):
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "transformer_output", None) is not None:
with tf.name_scope(self.transformer_output.name):
self.transformer_output.build(None)
class TFLxmertXLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -542,6 +652,32 @@ class TFLxmertXLayer(tf.keras.layers.Layer):
return (lang_output, visn_output, attention_probs[0]) if output_attentions else (lang_output, visn_output)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "visual_attention", None) is not None:
with tf.name_scope(self.visual_attention.name):
self.visual_attention.build(None)
if getattr(self, "lang_self_att", None) is not None:
with tf.name_scope(self.lang_self_att.name):
self.lang_self_att.build(None)
if getattr(self, "visn_self_att", None) is not None:
with tf.name_scope(self.visn_self_att.name):
self.visn_self_att.build(None)
if getattr(self, "lang_inter", None) is not None:
with tf.name_scope(self.lang_inter.name):
self.lang_inter.build(None)
if getattr(self, "lang_output", None) is not None:
with tf.name_scope(self.lang_output.name):
self.lang_output.build(None)
if getattr(self, "visn_inter", None) is not None:
with tf.name_scope(self.visn_inter.name):
self.visn_inter.build(None)
if getattr(self, "visn_output", None) is not None:
with tf.name_scope(self.visn_output.name):
self.visn_output.build(None)
class TFLxmertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -631,6 +767,26 @@ class TFLxmertEncoder(tf.keras.layers.Layer):
cross_encoder_attentions if output_attentions else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "visn_fc", None) is not None:
with tf.name_scope(self.visn_fc.name):
self.visn_fc.build(None)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
if getattr(self, "x_layers", None) is not None:
for layer in self.x_layers:
with tf.name_scope(layer.name):
layer.build(None)
if getattr(self, "r_layers", None) is not None:
for layer in self.r_layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFLxmertMainLayer(tf.keras.layers.Layer):
@ -771,6 +927,20 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFLxmertPreTrainedModel(TFPreTrainedModel):
"""
@ -966,6 +1136,14 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lxmert", None) is not None:
with tf.name_scope(self.lxmert.name):
self.lxmert.build(None)
class TFLxmertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -976,6 +1154,7 @@ class TFLxmertPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
@ -984,6 +1163,14 @@ class TFLxmertPooler(tf.keras.layers.Layer):
pooled_output = self.dense(first_token_tensor)
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Lxmert
class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer):
@ -1002,6 +1189,7 @@ class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -1010,6 +1198,17 @@ class TFLxmertPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Lxmert
class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
@ -1025,10 +1224,15 @@ class TFLxmertLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -1067,6 +1271,14 @@ class TFLxmertMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
class TFLxmertPreTrainingHeads(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -1078,12 +1290,24 @@ class TFLxmertPreTrainingHeads(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="seq_relationship",
)
self.config = config
def call(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
if getattr(self, "seq_relationship", None) is not None:
with tf.name_scope(self.seq_relationship.name):
self.seq_relationship.build([None, None, self.config.hidden_size])
class TFLxmertVisualAnswerHead(tf.keras.layers.Layer):
def __init__(self, config, num_labels, **kwargs):
@ -1101,6 +1325,7 @@ class TFLxmertVisualAnswerHead(tf.keras.layers.Layer):
kernel_initializer=get_initializer(config.initializer_range),
name="logit_fc_._3",
)
self.hid_dim = hid_dim
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -1110,6 +1335,20 @@ class TFLxmertVisualAnswerHead(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.hid_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, self.hid_dim * 2])
if getattr(self, "dense_1", None) is not None:
with tf.name_scope(self.dense_1.name):
self.dense_1.build([None, None, self.hid_dim * 2])
class TFLxmertVisualObjHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -1136,6 +1375,7 @@ class TFLxmertVisualObjHead(tf.keras.layers.Layer):
)
for key in self.visual_losses
}
self.config = config
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
@ -1144,6 +1384,18 @@ class TFLxmertVisualObjHead(tf.keras.layers.Layer):
output[key] = self.decoder_dict[key](hidden_states)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
if getattr(self, "decoder_dict", None) is not None:
for layer in self.decoder_dict.values():
with tf.name_scope(layer.name):
layer.build([None, None, self.config.hidden_size])
@add_start_docstrings("""Lxmert Model with a `language modeling` head on top.""", LXMERT_START_DOCSTRING)
class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
@ -1387,3 +1639,20 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
vision_attentions=lxmert_output.vision_attentions,
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lxmert", None) is not None:
with tf.name_scope(self.lxmert.name):
self.lxmert.build(None)
if getattr(self, "cls", None) is not None:
with tf.name_scope(self.cls.name):
self.cls.build(None)
if getattr(self, "obj_predict_head", None) is not None:
with tf.name_scope(self.obj_predict_head.name):
self.obj_predict_head.build(None)
if getattr(self, "answer_head", None) is not None:
with tf.name_scope(self.answer_head.name):
self.answer_head.build(None)

View File

@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -328,6 +327,23 @@ class TFMarianAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartEncoderLayer with Bart->Marian
class TFMarianEncoderLayer(tf.keras.layers.Layer):
@ -344,6 +360,7 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -385,6 +402,26 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartDecoderLayer with Bart->Marian
class TFMarianDecoderLayer(tf.keras.layers.Layer):
@ -414,6 +451,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -495,6 +533,32 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFMarianPreTrainedModel(TFPreTrainedModel):
config_class = MarianConfig
@ -743,16 +807,8 @@ class TFMarianEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -806,6 +862,18 @@ class TFMarianEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFMarianDecoder(tf.keras.layers.Layer):
@ -946,16 +1014,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1038,6 +1098,18 @@ class TFMarianDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFMarianMainLayer(tf.keras.layers.Layer):
@ -1149,6 +1221,22 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare MARIAN Model outputting raw hidden-states without any specific head on top.",
@ -1236,6 +1324,14 @@ class TFMarianModel(TFMarianPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -1443,3 +1539,14 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -297,6 +296,23 @@ class TFMBartAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFMBartEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: MBartConfig, **kwargs):
@ -312,6 +328,7 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -353,6 +370,26 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFMBartDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: MBartConfig, **kwargs):
@ -381,6 +418,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -462,6 +500,32 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFMBartPreTrainedModel(TFPreTrainedModel):
config_class = MBartConfig
@ -663,6 +727,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
self.layers = [TFMBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
self.embed_dim = config.d_model
def get_embed_tokens(self):
return self.embed_tokens
@ -735,16 +800,8 @@ class TFMBartEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -801,6 +858,24 @@ class TFMBartEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.embed_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFMBartDecoder(tf.keras.layers.Layer):
@ -945,16 +1020,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1040,6 +1107,24 @@ class TFMBartDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layernorm_embedding", None) is not None:
with tf.name_scope(self.layernorm_embedding.name):
self.layernorm_embedding.build([None, None, self.config.d_model])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFMBartMainLayer(tf.keras.layers.Layer):
@ -1154,6 +1239,22 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare MBART Model outputting raw hidden-states without any specific head on top.",
@ -1241,6 +1342,14 @@ class TFMBartModel(TFMBartPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -1446,3 +1555,14 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -130,6 +130,7 @@ class TFMobileBertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -137,11 +138,23 @@ class TFMobileBertIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.true_hidden_size])
class TFLayerNorm(tf.keras.layers.LayerNormalization):
def __init__(self, feat_size, *args, **kwargs):
self.feat_size = feat_size
super().__init__(*args, **kwargs)
def build(self, input_shape=None):
super().build([None, None, self.feat_size])
class TFNoNorm(tf.keras.layers.Layer):
def __init__(self, feat_size, epsilon=None, **kwargs):
@ -180,8 +193,9 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.embedded_input_size = self.embedding_size * (3 if self.trigram_input else 1)
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -203,7 +217,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "embedding_transformation", None) is not None:
with tf.name_scope(self.embedding_transformation.name):
self.embedding_transformation.build([None, None, self.embedded_input_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
"""
@ -281,6 +303,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -332,6 +355,28 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.true_hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.true_hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build(
[
None,
None,
self.config.true_hidden_size
if self.config.use_bottleneck_attention
else self.config.hidden_size,
]
)
class TFMobileBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -345,6 +390,7 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer):
)
if not self.use_bottleneck:
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, residual_tensor, training=False):
hidden_states = self.dense(hidden_states)
@ -353,6 +399,17 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states + residual_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.true_hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
class TFMobileBertAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -382,6 +439,17 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "mobilebert_output", None) is not None:
with tf.name_scope(self.mobilebert_output.name):
self.mobilebert_output.build(None)
class TFOutputBottleneck(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -391,6 +459,7 @@ class TFOutputBottleneck(tf.keras.layers.Layer):
config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states, residual_tensor, training=False):
layer_outputs = self.dense(hidden_states)
@ -398,6 +467,17 @@ class TFOutputBottleneck(tf.keras.layers.Layer):
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.true_hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
class TFMobileBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -413,6 +493,7 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
else:
self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
self.config = config
def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
hidden_states = self.dense(hidden_states)
@ -424,6 +505,20 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
hidden_states = self.bottleneck(hidden_states, residual_tensor_2)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
if getattr(self, "bottleneck", None) is not None:
with tf.name_scope(self.bottleneck.name):
self.bottleneck.build(None)
class TFBottleneckLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -432,12 +527,24 @@ class TFBottleneckLayer(tf.keras.layers.Layer):
self.LayerNorm = NORM2FN[config.normalization_type](
config.intra_bottleneck_size, epsilon=config.layer_norm_eps, name="LayerNorm"
)
self.config = config
def call(self, inputs):
hidden_states = self.dense(inputs)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
class TFBottleneck(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -474,6 +581,17 @@ class TFBottleneck(tf.keras.layers.Layer):
else:
return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "bottleneck_input", None) is not None:
with tf.name_scope(self.bottleneck_input.name):
self.bottleneck_input.build(None)
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
class TFFFNOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -482,12 +600,24 @@ class TFFFNOutput(tf.keras.layers.Layer):
self.LayerNorm = NORM2FN[config.normalization_type](
config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
)
self.config = config
def call(self, hidden_states, residual_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + residual_tensor)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
class TFFFNLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -500,6 +630,17 @@ class TFFFNLayer(tf.keras.layers.Layer):
layer_outputs = self.mobilebert_output(intermediate_output, hidden_states)
return layer_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "mobilebert_output", None) is not None:
with tf.name_scope(self.mobilebert_output.name):
self.mobilebert_output.build(None)
class TFMobileBertLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -560,6 +701,27 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "mobilebert_output", None) is not None:
with tf.name_scope(self.mobilebert_output.name):
self.mobilebert_output.build(None)
if getattr(self, "bottleneck", None) is not None:
with tf.name_scope(self.bottleneck.name):
self.bottleneck.build(None)
if getattr(self, "ffn", None) is not None:
for layer in self.ffn:
with tf.name_scope(layer.name):
layer.build(None)
class TFMobileBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -603,6 +765,15 @@ class TFMobileBertEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFMobileBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -615,6 +786,7 @@ class TFMobileBertPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
@ -626,6 +798,14 @@ class TFMobileBertPooler(tf.keras.layers.Layer):
pooled_output = self.dense(first_token_tensor)
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -638,6 +818,7 @@ class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer):
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -645,6 +826,17 @@ class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer):
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build(None)
class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -652,7 +844,7 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
self.transform = TFMobileBertPredictionHeadTransform(config, name="transform")
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
self.dense = self.add_weight(
shape=(self.config.hidden_size - self.config.embedding_size, self.config.vocab_size),
@ -666,7 +858,13 @@ class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
trainable=True,
name="decoder/weight",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self):
return self
@ -698,6 +896,14 @@ class TFMobileBertMLMHead(tf.keras.layers.Layer):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@keras_serializable
class TFMobileBertMainLayer(tf.keras.layers.Layer):
@ -814,6 +1020,20 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFMobileBertPreTrainedModel(TFPreTrainedModel):
"""
@ -998,6 +1218,14 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
@add_start_docstrings(
"""
@ -1011,7 +1239,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTra
super().__init__(config, *inputs, **kwargs)
self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls")
self.seq_relationship = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")
def get_lm_head(self):
return self.predictions.predictions
@ -1088,6 +1316,20 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTra
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
if getattr(self, "seq_relationship", None) is not None:
with tf.name_scope(self.seq_relationship.name):
self.seq_relationship.build(None)
def tf_to_pt_weight_rename(self, tf_weight):
if tf_weight == "cls.predictions.decoder.weight":
return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
@ -1174,6 +1416,17 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
def tf_to_pt_weight_rename(self, tf_weight):
if tf_weight == "cls.predictions.decoder.weight":
return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
@ -1185,11 +1438,20 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(2, name="seq_relationship")
self.config = config
def call(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "seq_relationship", None) is not None:
with tf.name_scope(self.seq_relationship.name):
self.seq_relationship.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""MobileBert Model with a `next sentence prediction (classification)` head on top.""",
@ -1272,6 +1534,17 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "cls", None) is not None:
with tf.name_scope(self.cls.name):
self.cls.build(None)
@add_start_docstrings(
"""
@ -1302,6 +1575,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1362,6 +1636,17 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1388,6 +1673,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1461,6 +1747,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1487,6 +1784,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1562,6 +1860,17 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1593,6 +1902,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1650,3 +1960,14 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilebert", None) is not None:
with tf.name_scope(self.mobilebert.name):
self.mobilebert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])

View File

@ -85,6 +85,7 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
def __init__(
self,
config: MobileViTConfig,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
@ -132,6 +133,8 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
self.activation = config.hidden_act
else:
self.activation = None
self.in_channels = in_channels
self.out_channels = out_channels
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
padded_features = self.padding(features)
@ -142,6 +145,18 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
features = self.activation(features)
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution", None) is not None:
with tf.name_scope(self.convolution.name):
self.convolution.build([None, None, None, self.in_channels])
if getattr(self, "normalization", None) is not None:
if hasattr(self.normalization, "name"):
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.out_channels])
class TFMobileViTInvertedResidual(tf.keras.layers.Layer):
"""
@ -160,11 +175,12 @@ class TFMobileViTInvertedResidual(tf.keras.layers.Layer):
self.use_residual = (stride == 1) and (in_channels == out_channels)
self.expand_1x1 = TFMobileViTConvLayer(
config, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
)
self.conv_3x3 = TFMobileViTConvLayer(
config,
in_channels=expanded_channels,
out_channels=expanded_channels,
kernel_size=3,
stride=stride,
@ -175,6 +191,7 @@ class TFMobileViTInvertedResidual(tf.keras.layers.Layer):
self.reduce_1x1 = TFMobileViTConvLayer(
config,
in_channels=expanded_channels,
out_channels=out_channels,
kernel_size=1,
use_activation=False,
@ -190,6 +207,20 @@ class TFMobileViTInvertedResidual(tf.keras.layers.Layer):
return residual + features if self.use_residual else features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "expand_1x1", None) is not None:
with tf.name_scope(self.expand_1x1.name):
self.expand_1x1.build(None)
if getattr(self, "conv_3x3", None) is not None:
with tf.name_scope(self.conv_3x3.name):
self.conv_3x3.build(None)
if getattr(self, "reduce_1x1", None) is not None:
with tf.name_scope(self.reduce_1x1.name):
self.reduce_1x1.build(None)
class TFMobileViTMobileNetLayer(tf.keras.layers.Layer):
def __init__(
@ -220,6 +251,15 @@ class TFMobileViTMobileNetLayer(tf.keras.layers.Layer):
features = layer_module(features, training=training)
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer_module in self.layers:
with tf.name_scope(layer_module.name):
layer_module.build(None)
class TFMobileViTSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
@ -242,6 +282,7 @@ class TFMobileViTSelfAttention(tf.keras.layers.Layer):
self.value = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value")
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.hidden_size = hidden_size
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
batch_size = tf.shape(x)[0]
@ -272,18 +313,41 @@ class TFMobileViTSelfAttention(tf.keras.layers.Layer):
context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size))
return context_layer
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.hidden_size])
class TFMobileViTSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.hidden_size = hidden_size
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.hidden_size])
class TFMobileViTAttention(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
@ -299,6 +363,17 @@ class TFMobileViTAttention(tf.keras.layers.Layer):
attention_output = self.dense_output(self_outputs, training=training)
return attention_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFMobileViTIntermediate(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
@ -308,18 +383,28 @@ class TFMobileViTIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.hidden_size = hidden_size
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.hidden_size])
class TFMobileViTOutput(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.intermediate_size = intermediate_size
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
@ -327,6 +412,14 @@ class TFMobileViTOutput(tf.keras.layers.Layer):
hidden_states = hidden_states + input_tensor
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.intermediate_size])
class TFMobileViTTransformerLayer(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
@ -340,6 +433,7 @@ class TFMobileViTTransformerLayer(tf.keras.layers.Layer):
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
self.hidden_size = hidden_size
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
attention_output = self.attention(self.layernorm_before(hidden_states), training=training)
@ -350,6 +444,26 @@ class TFMobileViTTransformerLayer(tf.keras.layers.Layer):
layer_output = self.mobilevit_output(layer_output, hidden_states, training=training)
return layer_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "mobilevit_output", None) is not None:
with tf.name_scope(self.mobilevit_output.name):
self.mobilevit_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.hidden_size])
class TFMobileViTTransformer(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None:
@ -370,6 +484,15 @@ class TFMobileViTTransformer(tf.keras.layers.Layer):
hidden_states = layer_module(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer_module in self.layers:
with tf.name_scope(layer_module.name):
layer_module.build(None)
class TFMobileViTLayer(tf.keras.layers.Layer):
"""
@ -405,11 +528,16 @@ class TFMobileViTLayer(tf.keras.layers.Layer):
self.downsampling_layer = None
self.conv_kxk = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="conv_kxk"
config,
in_channels=in_channels,
out_channels=in_channels,
kernel_size=config.conv_kernel_size,
name="conv_kxk",
)
self.conv_1x1 = TFMobileViTConvLayer(
config,
in_channels=in_channels,
out_channels=hidden_size,
kernel_size=1,
use_normalization=False,
@ -424,12 +552,17 @@ class TFMobileViTLayer(tf.keras.layers.Layer):
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.conv_projection = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=1, name="conv_projection"
config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection"
)
self.fusion = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion"
config,
in_channels=2 * in_channels,
out_channels=in_channels,
kernel_size=config.conv_kernel_size,
name="fusion",
)
self.hidden_size = hidden_size
def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]:
patch_width, patch_height = self.patch_width, self.patch_height
@ -528,6 +661,32 @@ class TFMobileViTLayer(tf.keras.layers.Layer):
features = self.fusion(tf.concat([residual, features], axis=-1), training=training)
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv_kxk", None) is not None:
with tf.name_scope(self.conv_kxk.name):
self.conv_kxk.build(None)
if getattr(self, "conv_1x1", None) is not None:
with tf.name_scope(self.conv_1x1.name):
self.conv_1x1.build(None)
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.hidden_size])
if getattr(self, "conv_projection", None) is not None:
with tf.name_scope(self.conv_projection.name):
self.conv_projection.build(None)
if getattr(self, "fusion", None) is not None:
with tf.name_scope(self.fusion.name):
self.fusion.build(None)
if getattr(self, "downsampling_layer", None) is not None:
with tf.name_scope(self.downsampling_layer.name):
self.downsampling_layer.build(None)
class TFMobileViTEncoder(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
@ -628,6 +787,15 @@ class TFMobileViTEncoder(tf.keras.layers.Layer):
return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer_module in self.layers:
with tf.name_scope(layer_module.name):
layer_module.build(None)
@keras_serializable
class TFMobileViTMainLayer(tf.keras.layers.Layer):
@ -640,6 +808,7 @@ class TFMobileViTMainLayer(tf.keras.layers.Layer):
self.conv_stem = TFMobileViTConvLayer(
config,
in_channels=config.num_channels,
out_channels=config.neck_hidden_sizes[0],
kernel_size=3,
stride=2,
@ -650,7 +819,11 @@ class TFMobileViTMainLayer(tf.keras.layers.Layer):
if self.expand_output:
self.conv_1x1_exp = TFMobileViTConvLayer(
config, out_channels=config.neck_hidden_sizes[6], kernel_size=1, name="conv_1x1_exp"
config,
in_channels=config.neck_hidden_sizes[5],
out_channels=config.neck_hidden_sizes[6],
kernel_size=1,
name="conv_1x1_exp",
)
self.pooler = tf.keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler")
@ -724,6 +897,23 @@ class TFMobileViTMainLayer(tf.keras.layers.Layer):
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv_stem", None) is not None:
with tf.name_scope(self.conv_stem.name):
self.conv_stem.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build([None, None, None, None])
if getattr(self, "conv_1x1_exp", None) is not None:
with tf.name_scope(self.conv_1x1_exp.name):
self.conv_1x1_exp.build(None)
class TFMobileViTPreTrainedModel(TFPreTrainedModel):
"""
@ -824,6 +1014,14 @@ class TFMobileViTModel(TFMobileViTPreTrainedModel):
output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilevit", None) is not None:
with tf.name_scope(self.mobilevit.name):
self.mobilevit.build(None)
@add_start_docstrings(
"""
@ -844,6 +1042,7 @@ class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceCl
self.classifier = (
tf.keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
@ -884,15 +1083,28 @@ class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceCl
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilevit", None) is not None:
with tf.name_scope(self.mobilevit.name):
self.mobilevit.build(None)
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]])
class TFMobileViTASPPPooling(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, out_channels: int, **kwargs) -> None:
def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None:
super().__init__(**kwargs)
self.global_pool = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool")
self.conv_1x1 = TFMobileViTConvLayer(
config,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
@ -908,6 +1120,17 @@ class TFMobileViTASPPPooling(tf.keras.layers.Layer):
features = tf.image.resize(features, size=spatial_size, method="bilinear")
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "global_pool", None) is not None:
with tf.name_scope(self.global_pool.name):
self.global_pool.build([None, None, None, None])
if getattr(self, "conv_1x1", None) is not None:
with tf.name_scope(self.conv_1x1.name):
self.conv_1x1.build(None)
class TFMobileViTASPP(tf.keras.layers.Layer):
"""
@ -917,6 +1140,7 @@ class TFMobileViTASPP(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
super().__init__(**kwargs)
in_channels = config.neck_hidden_sizes[-2]
out_channels = config.aspp_out_channels
if len(config.atrous_rates) != 3:
@ -926,6 +1150,7 @@ class TFMobileViTASPP(tf.keras.layers.Layer):
in_projection = TFMobileViTConvLayer(
config,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
use_activation="relu",
@ -937,6 +1162,7 @@ class TFMobileViTASPP(tf.keras.layers.Layer):
[
TFMobileViTConvLayer(
config,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
dilation=rate,
@ -947,11 +1173,14 @@ class TFMobileViTASPP(tf.keras.layers.Layer):
]
)
pool_layer = TFMobileViTASPPPooling(config, out_channels, name=f"convs.{len(config.atrous_rates) + 1}")
pool_layer = TFMobileViTASPPPooling(
config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}"
)
self.convs.append(pool_layer)
self.project = TFMobileViTConvLayer(
config,
in_channels=5 * out_channels,
out_channels=out_channels,
kernel_size=1,
use_activation="relu",
@ -973,6 +1202,18 @@ class TFMobileViTASPP(tf.keras.layers.Layer):
pooled_features = self.dropout(pooled_features, training=training)
return pooled_features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "project", None) is not None:
with tf.name_scope(self.project.name):
self.project.build(None)
if getattr(self, "convs", None) is not None:
for conv in self.convs:
with tf.name_scope(conv.name):
conv.build(None)
class TFMobileViTDeepLabV3(tf.keras.layers.Layer):
"""
@ -987,6 +1228,7 @@ class TFMobileViTDeepLabV3(tf.keras.layers.Layer):
self.classifier = TFMobileViTConvLayer(
config,
in_channels=config.aspp_out_channels,
out_channels=config.num_labels,
kernel_size=1,
use_normalization=False,
@ -1001,6 +1243,17 @@ class TFMobileViTDeepLabV3(tf.keras.layers.Layer):
features = self.classifier(features, training=training)
return features
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "aspp", None) is not None:
with tf.name_scope(self.aspp.name):
self.aspp.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1113,3 +1366,14 @@ class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel):
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mobilevit", None) is not None:
with tf.name_scope(self.mobilevit.name):
self.mobilevit.build(None)
if getattr(self, "segmentation_head", None) is not None:
with tf.name_scope(self.segmentation_head.name):
self.segmentation_head.build(None)

View File

@ -91,7 +91,7 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -106,7 +106,12 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(initializer_range=self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids):
"""
@ -165,6 +170,7 @@ class TFMPNetPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -174,6 +180,14 @@ class TFMPNetPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFMPNetSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -203,6 +217,7 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, x, batch_size):
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -247,6 +262,23 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
outputs = (o, attention_probs) if output_attentions else (o,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q", None) is not None:
with tf.name_scope(self.q.name):
self.q.build([None, None, self.config.hidden_size])
if getattr(self, "k", None) is not None:
with tf.name_scope(self.k.name):
self.k.build([None, None, self.config.hidden_size])
if getattr(self, "v", None) is not None:
with tf.name_scope(self.v.name):
self.v.build([None, None, self.config.hidden_size])
if getattr(self, "o", None) is not None:
with tf.name_scope(self.o.name):
self.o.build([None, None, self.config.hidden_size])
class TFMPNetAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -255,6 +287,7 @@ class TFMPNetAttention(tf.keras.layers.Layer):
self.attn = TFMPNetSelfAttention(config, name="attn")
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.config = config
def prune_heads(self, heads):
raise NotImplementedError
@ -267,6 +300,17 @@ class TFMPNetAttention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet
class TFMPNetIntermediate(tf.keras.layers.Layer):
@ -281,6 +325,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -288,6 +333,14 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet
class TFMPNetOutput(tf.keras.layers.Layer):
@ -299,6 +352,7 @@ class TFMPNetOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -307,6 +361,17 @@ class TFMPNetOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFMPNetLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -329,6 +394,20 @@ class TFMPNetLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "out", None) is not None:
with tf.name_scope(self.out.name):
self.out.build(None)
class TFMPNetEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -344,15 +423,20 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
self.layer = [TFMPNetLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
self.relative_attention_num_buckets = config.relative_attention_num_buckets
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
with tf.name_scope("relative_attention_bias"):
self.relative_attention_bias = self.add_weight(
name="embeddings",
shape=[self.relative_attention_num_buckets, self.n_heads],
initializer=get_initializer(self.initializer_range),
)
return super().build(input_shape)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
def call(
self,
@ -561,6 +645,20 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
MPNET_START_DOCSTRING = r"""
@ -693,6 +791,14 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
class TFMPNetLMHead(tf.keras.layers.Layer):
"""MPNet head for masked and permuted language modeling"""
@ -712,10 +818,18 @@ class TFMPNetLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -816,6 +930,17 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
class TFMPNetClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -832,6 +957,7 @@ class TFMPNetClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -841,6 +967,17 @@ class TFMPNetClassificationHead(tf.keras.layers.Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -913,6 +1050,17 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassif
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -930,6 +1078,7 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -999,6 +1148,17 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1019,6 +1179,7 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1073,6 +1234,17 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1092,6 +1264,7 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1159,3 +1332,14 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "mpnet", None) is not None:
with tf.name_scope(self.mpnet.name):
self.mpnet.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -78,6 +78,7 @@ class TFAttention(tf.keras.layers.Layer):
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.n_state = n_state
self.pruned_heads = set()
def prune_heads(self, heads):
@ -153,6 +154,17 @@ class TFAttention(tf.keras.layers.Layer):
outputs = [a] + attn_outputs[1:]
return outputs # a, (attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "c_attn", None) is not None:
with tf.name_scope(self.c_attn.name):
self.c_attn.build([None, None, self.n_state * 3])
if getattr(self, "c_proj", None) is not None:
with tf.name_scope(self.c_proj.name):
self.c_proj.build([None, None, self.n_state])
class TFMLP(tf.keras.layers.Layer):
def __init__(self, n_state, config, **kwargs):
@ -162,6 +174,8 @@ class TFMLP(tf.keras.layers.Layer):
self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
self.act = get_tf_activation("gelu")
self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.nx = nx
self.n_state = n_state
def call(self, x, training=False):
h = self.act(self.c_fc(x))
@ -169,6 +183,17 @@ class TFMLP(tf.keras.layers.Layer):
h2 = self.dropout(h2, training=training)
return h2
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "c_fc", None) is not None:
with tf.name_scope(self.c_fc.name):
self.c_fc.build([None, None, self.n_state])
if getattr(self, "c_proj", None) is not None:
with tf.name_scope(self.c_proj.name):
self.c_proj.build([None, None, self.nx])
class TFBlock(tf.keras.layers.Layer):
def __init__(self, config, scale=False, **kwargs):
@ -178,6 +203,7 @@ class TFBlock(tf.keras.layers.Layer):
self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
self.mlp = TFMLP(4 * nx, config, name="mlp")
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
self.nx = nx
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
@ -190,6 +216,23 @@ class TFBlock(tf.keras.layers.Layer):
outputs = [h] + output_attn[1:]
return outputs # x, (attentions)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "ln_1", None) is not None:
with tf.name_scope(self.ln_1.name):
self.ln_1.build([None, None, self.nx])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "ln_2", None) is not None:
with tf.name_scope(self.ln_2.name):
self.ln_2.build([None, None, self.nx])
@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
@ -213,7 +256,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
def build(self, input_shape):
def build(self, input_shape=None):
with tf.name_scope("positions_embed"):
self.positions_embed = self.add_weight(
name="embeddings",
@ -221,7 +264,16 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "tokens_embed", None) is not None:
with tf.name_scope(self.tokens_embed.name):
self.tokens_embed.build(None)
if getattr(self, "h", None) is not None:
for layer in self.h:
with tf.name_scope(layer.name):
layer.build(None)
def get_input_embeddings(self):
return self.tokens_embed
@ -528,6 +580,14 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -613,6 +673,14 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
def prepare_inputs_for_generation(self, inputs, **kwargs):
return {"input_ids": inputs}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -734,6 +802,17 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
"mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "multiple_choice_head", None) is not None:
with tf.name_scope(self.multiple_choice_head.name):
self.multiple_choice_head.build(None)
@add_start_docstrings(
"""
@ -761,6 +840,7 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
use_bias=False,
)
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@ -848,3 +928,14 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "score", None) is not None:
with tf.name_scope(self.score.name):
self.score.build([None, None, self.config.n_embd])
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)

View File

@ -268,6 +268,23 @@ class TFOPTAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFOPTDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: OPTConfig, **kwargs):
@ -288,6 +305,7 @@ class TFOPTDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -354,6 +372,26 @@ class TFOPTDecoderLayer(tf.keras.layers.Layer):
return (hidden_states, self_attn_weights, present_key_value)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
OPT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -696,6 +734,30 @@ class TFOPTDecoder(tf.keras.layers.Layer):
attentions=all_self_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_tokens", None) is not None:
with tf.name_scope(self.embed_tokens.name):
self.embed_tokens.build(None)
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "project_out", None) is not None:
with tf.name_scope(self.project_out.name):
self.project_out.build([None, None, self.config.hidden_size])
if getattr(self, "project_in", None) is not None:
with tf.name_scope(self.project_in.name):
self.project_in.build([None, None, self.config.word_embed_proj_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFOPTMainLayer(tf.keras.layers.Layer):
@ -757,6 +819,14 @@ class TFOPTMainLayer(tf.keras.layers.Layer):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare TF OPT Model outputting raw hidden-states without any specific head on top.",
@ -841,6 +911,14 @@ class TFOPTModel(TFOPTPreTrainedModel):
attentions=attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
@add_start_docstrings(
"""
@ -1006,3 +1084,11 @@ class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
loss=output.loss,
logits=output.logits,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)

View File

@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
@ -330,6 +329,23 @@ class TFPegasusAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus
class TFPegasusEncoderLayer(tf.keras.layers.Layer):
@ -346,6 +362,7 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -387,6 +404,26 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus
class TFPegasusDecoderLayer(tf.keras.layers.Layer):
@ -416,6 +453,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -497,6 +535,32 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFPegasusPreTrainedModel(TFPreTrainedModel):
config_class = PegasusConfig
@ -747,16 +811,8 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -812,6 +868,21 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFPegasusDecoder(tf.keras.layers.Layer):
@ -953,16 +1024,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
hidden_states = inputs_embeds
@ -1047,6 +1110,21 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFPegasusMainLayer(tf.keras.layers.Layer):
@ -1158,6 +1236,22 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
@ -1245,6 +1339,14 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@ -1452,3 +1554,14 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)

View File

@ -1292,6 +1292,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return loss
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rag", None) is not None:
with tf.name_scope(self.rag.name):
self.rag.build(None)
@add_start_docstrings_to_model_forward(
"""
@ -1743,3 +1751,11 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
output = tf.convert_to_tensor(output)
return tf.cast(output, tensors[0][0][0].dtype)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rag", None) is not None:
with tf.name_scope(self.rag.name):
self.rag.build(None)

View File

@ -53,6 +53,7 @@ TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFRegNetConvLayer(tf.keras.layers.Layer):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
@ -75,6 +76,8 @@ class TFRegNetConvLayer(tf.keras.layers.Layer):
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.identity
self.in_channels = in_channels
self.out_channels = out_channels
def call(self, hidden_state):
hidden_state = self.convolution(self.padding(hidden_state))
@ -82,6 +85,17 @@ class TFRegNetConvLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution", None) is not None:
with tf.name_scope(self.convolution.name):
self.convolution.build([None, None, None, self.in_channels])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.out_channels])
class TFRegNetEmbeddings(tf.keras.layers.Layer):
"""
@ -92,6 +106,7 @@ class TFRegNetEmbeddings(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.num_channels = config.num_channels
self.embedder = TFRegNetConvLayer(
in_channels=config.num_channels,
out_channels=config.embedding_size,
kernel_size=3,
stride=2,
@ -113,6 +128,14 @@ class TFRegNetEmbeddings(tf.keras.layers.Layer):
hidden_state = self.embedder(pixel_values)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedder", None) is not None:
with tf.name_scope(self.embedder.name):
self.embedder.build(None)
class TFRegNetShortCut(tf.keras.layers.Layer):
"""
@ -120,16 +143,29 @@ class TFRegNetShortCut(tf.keras.layers.Layer):
downsample the input using `stride=2`.
"""
def __init__(self, out_channels: int, stride: int = 2, **kwargs):
def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs):
super().__init__(**kwargs)
self.convolution = tf.keras.layers.Conv2D(
filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.in_channels = in_channels
self.out_channels = out_channels
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
return self.normalization(self.convolution(inputs), training=training)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution", None) is not None:
with tf.name_scope(self.convolution.name):
self.convolution.build([None, None, None, self.in_channels])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.out_channels])
class TFRegNetSELayer(tf.keras.layers.Layer):
"""
@ -143,6 +179,8 @@ class TFRegNetSELayer(tf.keras.layers.Layer):
tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
tf.keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
]
self.in_channels = in_channels
self.reduced_channels = reduced_channels
def call(self, hidden_state):
# [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels]
@ -152,6 +190,19 @@ class TFRegNetSELayer(tf.keras.layers.Layer):
hidden_state = hidden_state * pooled
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build((None, None, None, None))
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention[0].name):
self.attention[0].build([None, None, None, self.in_channels])
with tf.name_scope(self.attention[1].name):
self.attention[1].build([None, None, None, self.reduced_channels])
class TFRegNetXLayer(tf.keras.layers.Layer):
"""
@ -163,17 +214,17 @@ class TFRegNetXLayer(tf.keras.layers.Layer):
should_apply_shortcut = in_channels != out_channels or stride != 1
groups = max(1, out_channels // config.groups_width)
self.shortcut = (
TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
# `self.layers` instead of `self.layer` because that is a reserved argument.
self.layers = [
TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(
out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
),
TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2"),
TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"),
]
self.activation = ACT2FN[config.hidden_act]
@ -186,6 +237,18 @@ class TFRegNetXLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "shortcut", None) is not None:
with tf.name_scope(self.shortcut.name):
self.shortcut.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFRegNetYLayer(tf.keras.layers.Layer):
"""
@ -197,17 +260,17 @@ class TFRegNetYLayer(tf.keras.layers.Layer):
should_apply_shortcut = in_channels != out_channels or stride != 1
groups = max(1, out_channels // config.groups_width)
self.shortcut = (
TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
self.layers = [
TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(
out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
),
TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.3"),
TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"),
]
self.activation = ACT2FN[config.hidden_act]
@ -220,6 +283,18 @@ class TFRegNetYLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "shortcut", None) is not None:
with tf.name_scope(self.shortcut.name):
self.shortcut.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFRegNetStage(tf.keras.layers.Layer):
"""
@ -243,6 +318,15 @@ class TFRegNetStage(tf.keras.layers.Layer):
hidden_state = layer_module(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFRegNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: RegNetConfig, **kwargs):
@ -282,6 +366,14 @@ class TFRegNetEncoder(tf.keras.layers.Layer):
return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
for stage in self.stages:
with tf.name_scope(stage.name):
stage.build(None)
@keras_serializable
class TFRegNetMainLayer(tf.keras.layers.Layer):
@ -333,6 +425,20 @@ class TFRegNetMainLayer(tf.keras.layers.Layer):
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedder", None) is not None:
with tf.name_scope(self.embedder.name):
self.embedder.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build((None, None, None, None))
class TFRegNetPreTrainedModel(TFPreTrainedModel):
"""
@ -418,6 +524,14 @@ class TFRegNetModel(TFRegNetPreTrainedModel):
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "regnet", None) is not None:
with tf.name_scope(self.regnet.name):
self.regnet.build(None)
@add_start_docstrings(
"""
@ -479,3 +593,14 @@ class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassifi
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "regnet", None) is not None:
with tf.name_scope(self.regnet.name):
self.regnet.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier[1].name):
self.classifier[1].build([None, None, None, self.config.hidden_sizes[-1]])

View File

@ -80,7 +80,7 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -102,7 +102,12 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.input_embedding_size])
def call(
self,
@ -172,6 +177,7 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -261,6 +267,20 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RemBert
class TFRemBertSelfOutput(tf.keras.layers.Layer):
@ -272,6 +292,7 @@ class TFRemBertSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -280,6 +301,17 @@ class TFRemBertSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->RemBert
class TFRemBertAttention(tf.keras.layers.Layer):
@ -321,6 +353,17 @@ class TFRemBertAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RemBert
class TFRemBertIntermediate(tf.keras.layers.Layer):
@ -335,6 +378,7 @@ class TFRemBertIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -342,6 +386,14 @@ class TFRemBertIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RemBert
class TFRemBertOutput(tf.keras.layers.Layer):
@ -353,6 +405,7 @@ class TFRemBertOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -361,6 +414,17 @@ class TFRemBertOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RemBert
class TFRemBertLayer(tf.keras.layers.Layer):
@ -448,6 +512,23 @@ class TFRemBertLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
class TFRemBertEncoder(tf.keras.layers.Layer):
def __init__(self, config: RemBertConfig, **kwargs):
@ -524,6 +605,18 @@ class TFRemBertEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedding_hidden_mapping_in", None) is not None:
with tf.name_scope(self.embedding_hidden_mapping_in.name):
self.embedding_hidden_mapping_in.build([None, None, self.config.input_embedding_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->RemBert
class TFRemBertPooler(tf.keras.layers.Layer):
@ -536,6 +629,7 @@ class TFRemBertPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -545,6 +639,14 @@ class TFRemBertPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFRemBertLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config: RemBertConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
@ -562,7 +664,7 @@ class TFRemBertLMPredictionHead(tf.keras.layers.Layer):
self.activation = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.decoder = self.add_weight(
name="decoder/weight",
shape=[self.config.vocab_size, self.output_embedding_size],
@ -572,7 +674,15 @@ class TFRemBertLMPredictionHead(tf.keras.layers.Layer):
shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, self.config.output_embedding_size])
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self
@ -612,6 +722,14 @@ class TFRemBertMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@keras_serializable
class TFRemBertMainLayer(tf.keras.layers.Layer):
@ -800,6 +918,20 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFRemBertPreTrainedModel(TFPreTrainedModel):
"""
@ -982,6 +1114,14 @@ class TFRemBertModel(TFRemBertPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1054,6 +1194,17 @@ class TFRemBertForMaskedLM(TFRemBertPreTrainedModel, TFMaskedLanguageModelingLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
@ -1170,6 +1321,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""
@ -1190,6 +1352,7 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1246,6 +1409,17 @@ class TFRemBertForSequenceClassification(TFRemBertPreTrainedModel, TFSequenceCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1263,6 +1437,7 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1342,6 +1517,17 @@ class TFRemBertForMultipleChoice(TFRemBertPreTrainedModel, TFMultipleChoiceLoss)
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1361,6 +1547,7 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1415,6 +1602,17 @@ class TFRemBertForTokenClassification(TFRemBertPreTrainedModel, TFTokenClassific
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1433,6 +1631,7 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1501,3 +1700,14 @@ class TFRemBertForQuestionAnswering(TFRemBertPreTrainedModel, TFQuestionAnswerin
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rembert", None) is not None:
with tf.name_scope(self.rembert.name):
self.rembert.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -51,7 +51,13 @@ TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFResNetConvLayer(tf.keras.layers.Layer):
def __init__(
self, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu", **kwargs
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
activation: str = "relu",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.pad_value = kernel_size // 2
@ -61,6 +67,8 @@ class TFResNetConvLayer(tf.keras.layers.Layer):
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear")
self.in_channels = in_channels
self.out_channels = out_channels
def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
# Pad to match that done in the PyTorch Conv2D model
@ -75,6 +83,17 @@ class TFResNetConvLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, None, self.in_channels])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.out_channels])
class TFResNetEmbeddings(tf.keras.layers.Layer):
"""
@ -84,6 +103,7 @@ class TFResNetEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.embedder = TFResNetConvLayer(
config.num_channels,
config.embedding_size,
kernel_size=7,
stride=2,
@ -105,6 +125,17 @@ class TFResNetEmbeddings(tf.keras.layers.Layer):
hidden_state = self.pooler(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedder", None) is not None:
with tf.name_scope(self.embedder.name):
self.embedder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFResNetShortCut(tf.keras.layers.Layer):
"""
@ -112,13 +143,15 @@ class TFResNetShortCut(tf.keras.layers.Layer):
downsample the input using `stride=2`.
"""
def __init__(self, out_channels: int, stride: int = 2, **kwargs) -> None:
def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs) -> None:
super().__init__(**kwargs)
self.convolution = tf.keras.layers.Conv2D(
out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
self.in_channels = in_channels
self.out_channels = out_channels
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = x
@ -126,6 +159,17 @@ class TFResNetShortCut(tf.keras.layers.Layer):
hidden_state = self.normalization(hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "convolution", None) is not None:
with tf.name_scope(self.convolution.name):
self.convolution.build([None, None, None, self.in_channels])
if getattr(self, "normalization", None) is not None:
with tf.name_scope(self.normalization.name):
self.normalization.build([None, None, None, self.out_channels])
class TFResNetBasicLayer(tf.keras.layers.Layer):
"""
@ -137,10 +181,10 @@ class TFResNetBasicLayer(tf.keras.layers.Layer):
) -> None:
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
self.conv1 = TFResNetConvLayer(out_channels, stride=stride, name="layer.0")
self.conv2 = TFResNetConvLayer(out_channels, activation=None, name="layer.1")
self.conv1 = TFResNetConvLayer(in_channels, out_channels, stride=stride, name="layer.0")
self.conv2 = TFResNetConvLayer(out_channels, out_channels, activation=None, name="layer.1")
self.shortcut = (
TFResNetShortCut(out_channels, stride=stride, name="shortcut")
TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
@ -155,6 +199,20 @@ class TFResNetBasicLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv1", None) is not None:
with tf.name_scope(self.conv1.name):
self.conv1.build(None)
if getattr(self, "conv2", None) is not None:
with tf.name_scope(self.conv2.name):
self.conv2.build(None)
if getattr(self, "shortcut", None) is not None:
with tf.name_scope(self.shortcut.name):
self.shortcut.build(None)
class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
"""
@ -176,11 +234,11 @@ class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
reduces_channels = out_channels // reduction
self.conv0 = TFResNetConvLayer(reduces_channels, kernel_size=1, name="layer.0")
self.conv1 = TFResNetConvLayer(reduces_channels, stride=stride, name="layer.1")
self.conv2 = TFResNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2")
self.conv0 = TFResNetConvLayer(in_channels, reduces_channels, kernel_size=1, name="layer.0")
self.conv1 = TFResNetConvLayer(reduces_channels, reduces_channels, stride=stride, name="layer.1")
self.conv2 = TFResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None, name="layer.2")
self.shortcut = (
TFResNetShortCut(out_channels, stride=stride, name="shortcut")
TFResNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
@ -196,6 +254,23 @@ class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
hidden_state = self.activation(hidden_state)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv0", None) is not None:
with tf.name_scope(self.conv0.name):
self.conv0.build(None)
if getattr(self, "conv1", None) is not None:
with tf.name_scope(self.conv1.name):
self.conv1.build(None)
if getattr(self, "conv2", None) is not None:
with tf.name_scope(self.conv2.name):
self.conv2.build(None)
if getattr(self, "shortcut", None) is not None:
with tf.name_scope(self.shortcut.name):
self.shortcut.build(None)
class TFResNetStage(tf.keras.layers.Layer):
"""
@ -221,6 +296,15 @@ class TFResNetStage(tf.keras.layers.Layer):
hidden_state = layer(hidden_state, training=training)
return hidden_state
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "stage_layers", None) is not None:
for layer in self.stage_layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFResNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: ResNetConfig, **kwargs) -> None:
@ -264,6 +348,15 @@ class TFResNetEncoder(tf.keras.layers.Layer):
return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "stages", None) is not None:
for layer in self.stages:
with tf.name_scope(layer.name):
layer.build(None)
class TFResNetPreTrainedModel(TFPreTrainedModel):
"""
@ -364,6 +457,17 @@ class TFResNetMainLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embedder", None) is not None:
with tf.name_scope(self.embedder.name):
self.embedder.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
@add_start_docstrings(
"The bare ResNet model outputting raw features without any specific head on top.",
@ -403,6 +507,14 @@ class TFResNetModel(TFResNetPreTrainedModel):
)
return resnet_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "resnet", None) is not None:
with tf.name_scope(self.resnet.name):
self.resnet.build(None)
@add_start_docstrings(
"""
@ -422,6 +534,7 @@ class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassifi
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="classifier.1")
)
self.config = config
def classifier(self, x: tf.Tensor) -> tf.Tensor:
x = tf.keras.layers.Flatten()(x)
@ -466,3 +579,14 @@ class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassifi
return (loss,) + output if loss is not None else output
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "resnet", None) is not None:
with tf.name_scope(self.resnet.name):
self.resnet.build(None)
if getattr(self, "classifier_layer", None) is not None:
with tf.name_scope(self.classifier_layer.name):
self.classifier_layer.build([None, None, self.config.hidden_sizes[-1]])

View File

@ -89,7 +89,7 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -111,7 +111,12 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
@ -184,6 +189,7 @@ class TFRobertaPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -193,6 +199,14 @@ class TFRobertaPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Roberta
class TFRobertaSelfAttention(tf.keras.layers.Layer):
@ -222,6 +236,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -311,6 +326,20 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Roberta
class TFRobertaSelfOutput(tf.keras.layers.Layer):
@ -322,6 +351,7 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -330,6 +360,17 @@ class TFRobertaSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Roberta
class TFRobertaAttention(tf.keras.layers.Layer):
@ -371,6 +412,17 @@ class TFRobertaAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Roberta
class TFRobertaIntermediate(tf.keras.layers.Layer):
@ -385,6 +437,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -392,6 +445,14 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Roberta
class TFRobertaOutput(tf.keras.layers.Layer):
@ -403,6 +464,7 @@ class TFRobertaOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -411,6 +473,17 @@ class TFRobertaOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Roberta
class TFRobertaLayer(tf.keras.layers.Layer):
@ -498,6 +571,23 @@ class TFRobertaLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Roberta
class TFRobertaEncoder(tf.keras.layers.Layer):
@ -568,6 +658,15 @@ class TFRobertaEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFRobertaMainLayer(tf.keras.layers.Layer):
@ -765,6 +864,20 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
class TFRobertaPreTrainedModel(TFPreTrainedModel):
"""
@ -946,6 +1059,14 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
class TFRobertaLMHead(tf.keras.layers.Layer):
"""Roberta Head for masked language modeling."""
@ -965,10 +1086,18 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -1076,6 +1205,17 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
@ -1198,6 +1338,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
class TFRobertaClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -1217,6 +1368,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1226,6 +1378,17 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1302,6 +1465,17 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1323,6 +1497,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1392,6 +1567,17 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1417,6 +1603,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1475,6 +1662,17 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1495,6 +1693,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1566,3 +1765,14 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -94,7 +94,7 @@ class TFRobertaPreLayerNormEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -116,7 +116,12 @@ class TFRobertaPreLayerNormEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
@ -189,6 +194,7 @@ class TFRobertaPreLayerNormPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -198,6 +204,14 @@ class TFRobertaPreLayerNormPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->RobertaPreLayerNorm
class TFRobertaPreLayerNormSelfAttention(tf.keras.layers.Layer):
@ -227,6 +241,7 @@ class TFRobertaPreLayerNormSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -316,6 +331,20 @@ class TFRobertaPreLayerNormSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
class TFRobertaPreLayerNormSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: RobertaPreLayerNormConfig, **kwargs):
@ -325,6 +354,7 @@ class TFRobertaPreLayerNormSelfOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -333,6 +363,14 @@ class TFRobertaPreLayerNormSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFRobertaPreLayerNormAttention(tf.keras.layers.Layer):
def __init__(self, config: RobertaPreLayerNormConfig, **kwargs):
@ -341,6 +379,7 @@ class TFRobertaPreLayerNormAttention(tf.keras.layers.Layer):
self.self_attention = TFRobertaPreLayerNormSelfAttention(config, name="self")
self.dense_output = TFRobertaPreLayerNormSelfOutput(config, name="output")
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention.prune_heads
def prune_heads(self, heads):
@ -376,6 +415,20 @@ class TFRobertaPreLayerNormAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFRobertaPreLayerNormIntermediate(tf.keras.layers.Layer):
def __init__(self, config: RobertaPreLayerNormConfig, **kwargs):
@ -390,6 +443,7 @@ class TFRobertaPreLayerNormIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.LayerNorm(inputs=hidden_states)
@ -398,6 +452,17 @@ class TFRobertaPreLayerNormIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFRobertaPreLayerNormOutput(tf.keras.layers.Layer):
def __init__(self, config: RobertaPreLayerNormConfig, **kwargs):
@ -407,6 +472,7 @@ class TFRobertaPreLayerNormOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -415,6 +481,14 @@ class TFRobertaPreLayerNormOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->RobertaPreLayerNorm
class TFRobertaPreLayerNormLayer(tf.keras.layers.Layer):
@ -502,6 +576,23 @@ class TFRobertaPreLayerNormLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->RobertaPreLayerNorm
class TFRobertaPreLayerNormEncoder(tf.keras.layers.Layer):
@ -572,6 +663,15 @@ class TFRobertaPreLayerNormEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFRobertaPreLayerNormMainLayer(tf.keras.layers.Layer):
@ -765,6 +865,23 @@ class TFRobertaPreLayerNormMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
class TFRobertaPreLayerNormPreTrainedModel(TFPreTrainedModel):
@ -948,6 +1065,14 @@ class TFRobertaPreLayerNormModel(TFRobertaPreLayerNormPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormLMHead(tf.keras.layers.Layer):
@ -968,10 +1093,18 @@ class TFRobertaPreLayerNormLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -1085,6 +1218,17 @@ class TFRobertaPreLayerNormForMaskedLM(TFRobertaPreLayerNormPreTrainedModel, TFM
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFCausalLanguageModelingLoss):
@ -1214,6 +1358,17 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
@ -1234,6 +1389,7 @@ class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1243,6 +1399,17 @@ class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1322,6 +1489,17 @@ class TFRobertaPreLayerNormForSequenceClassification(
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1344,6 +1522,7 @@ class TFRobertaPreLayerNormForMultipleChoice(TFRobertaPreLayerNormPreTrainedMode
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1415,6 +1594,17 @@ class TFRobertaPreLayerNormForMultipleChoice(TFRobertaPreLayerNormPreTrainedMode
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1442,6 +1632,7 @@ class TFRobertaPreLayerNormForTokenClassification(TFRobertaPreLayerNormPreTraine
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1499,6 +1690,17 @@ class TFRobertaPreLayerNormForTokenClassification(TFRobertaPreLayerNormPreTraine
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1521,6 +1723,7 @@ class TFRobertaPreLayerNormForQuestionAnswering(TFRobertaPreLayerNormPreTrainedM
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1591,3 +1794,14 @@ class TFRobertaPreLayerNormForQuestionAnswering(TFRobertaPreLayerNormPreTrainedM
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta_prelayernorm", None) is not None:
with tf.name_scope(self.roberta_prelayernorm.name):
self.roberta_prelayernorm.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -142,7 +142,7 @@ class TFRoFormerEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -157,7 +157,12 @@ class TFRoFormerEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
def call(
self,
@ -218,6 +223,7 @@ class TFRoFormerSelfAttention(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.rotary_value = config.rotary_value
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -307,6 +313,20 @@ class TFRoFormerSelfAttention(tf.keras.layers.Layer):
return query_layer, key_layer, value_layer
return query_layer, key_layer
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->RoFormer
class TFRoFormerSelfOutput(tf.keras.layers.Layer):
@ -318,6 +338,7 @@ class TFRoFormerSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -326,6 +347,17 @@ class TFRoFormerSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFRoFormerAttention(tf.keras.layers.Layer):
def __init__(self, config: RoFormerConfig, **kwargs):
@ -361,6 +393,17 @@ class TFRoFormerAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->RoFormer
class TFRoFormerIntermediate(tf.keras.layers.Layer):
@ -375,6 +418,7 @@ class TFRoFormerIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -382,6 +426,14 @@ class TFRoFormerIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->RoFormer
class TFRoFormerOutput(tf.keras.layers.Layer):
@ -393,6 +445,7 @@ class TFRoFormerOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -401,6 +454,17 @@ class TFRoFormerOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
class TFRoFormerLayer(tf.keras.layers.Layer):
def __init__(self, config: RoFormerConfig, **kwargs):
@ -436,6 +500,20 @@ class TFRoFormerLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "roformer_output", None) is not None:
with tf.name_scope(self.roformer_output.name):
self.roformer_output.build(None)
class TFRoFormerEncoder(tf.keras.layers.Layer):
def __init__(self, config: RoFormerConfig, **kwargs):
@ -491,6 +569,18 @@ class TFRoFormerEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFRoFormerPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config: RoFormerConfig, **kwargs):
@ -508,6 +598,7 @@ class TFRoFormerPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -516,6 +607,17 @@ class TFRoFormerPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.embedding_size])
class TFRoFormerLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config: RoFormerConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
@ -530,10 +632,15 @@ class TFRoFormerLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -572,6 +679,14 @@ class TFRoFormerMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@keras_serializable
class TFRoFormerMainLayer(tf.keras.layers.Layer):
@ -687,6 +802,20 @@ class TFRoFormerMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "embeddings_project", None) is not None:
with tf.name_scope(self.embeddings_project.name):
self.embeddings_project.build([None, None, self.config.embedding_size])
class TFRoFormerPreTrainedModel(TFPreTrainedModel):
"""
@ -834,6 +963,14 @@ class TFRoFormerModel(TFRoFormerPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -904,6 +1041,17 @@ class TFRoFormerForMaskedLM(TFRoFormerPreTrainedModel, TFMaskedLanguageModelingL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
@add_start_docstrings(
"""RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
@ -977,6 +1125,17 @@ class TFRoFormerForCausalLM(TFRoFormerPreTrainedModel, TFCausalLanguageModelingL
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "mlm", None) is not None:
with tf.name_scope(self.mlm.name):
self.mlm.build(None)
class TFRoFormerClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@ -996,6 +1155,7 @@ class TFRoFormerClassificationHead(tf.keras.layers.Layer):
self.classifier_act_fn = get_tf_activation(config.hidden_act)
else:
self.classifier_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1007,6 +1167,17 @@ class TFRoFormerClassificationHead(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1075,6 +1246,17 @@ class TFRoFormerForSequenceClassification(TFRoFormerPreTrainedModel, TFSequenceC
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1092,6 +1274,7 @@ class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLos
self.classifier = tf.keras.layers.Dense(
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1167,6 +1350,20 @@ class TFRoFormerForMultipleChoice(TFRoFormerPreTrainedModel, TFMultipleChoiceLos
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1186,6 +1383,7 @@ class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassif
self.classifier = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1238,6 +1436,17 @@ class TFRoFormerForTokenClassification(TFRoFormerPreTrainedModel, TFTokenClassif
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1256,6 +1465,7 @@ class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnswer
self.qa_outputs = tf.keras.layers.Dense(
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1321,3 +1531,14 @@ class TFRoFormerForQuestionAnswering(TFRoFormerPreTrainedModel, TFQuestionAnswer
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roformer", None) is not None:
with tf.name_scope(self.roformer.name):
self.roformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -150,6 +150,14 @@ class TFSamPatchEmbeddings(tf.keras.layers.Layer):
embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFSamMLPBlock(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -157,6 +165,7 @@ class TFSamMLPBlock(tf.keras.layers.Layer):
self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1")
self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2")
self.act = ACT2FN[config.hidden_act]
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.lin1(hidden_states)
@ -164,6 +173,17 @@ class TFSamMLPBlock(tf.keras.layers.Layer):
hidden_states = self.lin2(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lin1", None) is not None:
with tf.name_scope(self.lin1.name):
self.lin1.build([None, None, self.config.hidden_size])
if getattr(self, "lin2", None) is not None:
with tf.name_scope(self.lin2.name):
self.lin2.build([None, None, self.config.mlp_dim])
class TFSamLayerNorm(tf.keras.layers.Layer):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
@ -257,6 +277,23 @@ class TFSamAttention(tf.keras.layers.Layer):
return out
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.hidden_size])
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.hidden_size])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.internal_dim])
class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
@ -345,6 +382,35 @@ class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, None, self.hidden_size])
if getattr(self, "cross_attn_token_to_image", None) is not None:
with tf.name_scope(self.cross_attn_token_to_image.name):
self.cross_attn_token_to_image.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, None, self.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
if getattr(self, "layer_norm3", None) is not None:
with tf.name_scope(self.layer_norm3.name):
self.layer_norm3.build([None, None, None, self.hidden_size])
if getattr(self, "layer_norm4", None) is not None:
with tf.name_scope(self.layer_norm4.name):
self.layer_norm4.build([None, None, None, self.hidden_size])
if getattr(self, "cross_attn_image_to_token", None) is not None:
with tf.name_scope(self.cross_attn_image_to_token.name):
self.cross_attn_image_to_token.build(None)
class TFSamTwoWayTransformer(tf.keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
@ -412,6 +478,20 @@ class TFSamTwoWayTransformer(tf.keras.layers.Layer):
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "final_attn_token_to_image", None) is not None:
with tf.name_scope(self.final_attn_token_to_image.name):
self.final_attn_token_to_image.build(None)
if getattr(self, "layer_norm_final_attn", None) is not None:
with tf.name_scope(self.layer_norm_final_attn.name):
self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFSamFeedForward(tf.keras.layers.Layer):
def __init__(
@ -427,6 +507,8 @@ class TFSamFeedForward(tf.keras.layers.Layer):
for i in range(num_layers - 2)
]
self.sigmoid_output = sigmoid_output
self.hidden_dim = hidden_dim
self.input_dim = input_dim
def call(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
@ -439,6 +521,21 @@ class TFSamFeedForward(tf.keras.layers.Layer):
hidden_states = tf.sigmoid(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "proj_in", None) is not None:
with tf.name_scope(self.proj_in.name):
self.proj_in.build([None, None, self.input_dim])
if getattr(self, "proj_out", None) is not None:
with tf.name_scope(self.proj_out.name):
self.proj_out.build([None, None, self.hidden_dim])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build([None, None, self.hidden_dim])
class TFSamMaskDecoder(tf.keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
@ -483,12 +580,30 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
name="iou_prediction_head",
)
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
self.mask_tokens = self.add_weight(
shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
)
super().build(input_shape)
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "upscale_conv1", None) is not None:
with tf.name_scope(self.upscale_conv1.name):
self.upscale_conv1.build([None, self.hidden_size, None, None])
if getattr(self, "upscale_conv2", None) is not None:
with tf.name_scope(self.upscale_conv2.name):
self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
if getattr(self, "upscale_layer_norm", None) is not None:
with tf.name_scope(self.upscale_layer_norm.name):
self.upscale_layer_norm.build(None)
if getattr(self, "iou_prediction_head", None) is not None:
with tf.name_scope(self.iou_prediction_head.name):
self.iou_prediction_head.build(None)
def call(
self,
@ -615,6 +730,7 @@ class TFSamMaskEmbedding(tf.keras.layers.Layer):
self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
self.config = config
def call(self, masks):
masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
@ -629,24 +745,21 @@ class TFSamMaskEmbedding(tf.keras.layers.Layer):
dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
return dense_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
# This class needs an explicit build method because it isn't called with the standard dummy inputs
conv1_shape = [None, None, None, 1]
conv2_shape = [None, None, None, self.mask_input_channels]
conv3_shape = [None, None, None, self.mask_input_channels * 4]
layer_norm1_shape = [None, None, None, self.mask_input_channels]
layer_norm2_shape = [None, None, None, self.mask_input_channels * 4]
if self.built:
return
self.built = True
with tf.name_scope("conv1"):
self.conv1.build(conv1_shape)
self.conv1.build([None, None, None, 1])
with tf.name_scope("conv2"):
self.conv2.build(conv2_shape)
self.conv2.build([None, None, None, self.mask_input_channels])
with tf.name_scope("conv3"):
self.conv3.build(conv3_shape)
self.conv3.build([None, None, None, self.mask_input_channels * 4])
with tf.name_scope("layer_norm1"):
self.layer_norm1.build(layer_norm1_shape)
self.layer_norm1.build([None, None, None, self.mask_input_channels])
with tf.name_scope("layer_norm2"):
self.layer_norm2.build(layer_norm2_shape)
super().build(input_shape)
self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
class TFSamPromptEncoder(tf.keras.layers.Layer):
@ -664,7 +777,7 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
self.not_a_point_embed = None
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
self.no_mask_embed = self.add_weight(
name="no_mask_embed.weight",
shape=(1, self.hidden_size),
@ -691,7 +804,13 @@ class TFSamPromptEncoder(tf.keras.layers.Layer):
self.mask_embed.build(
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "mask_embed", None) is not None:
with tf.name_scope(self.mask_embed.name):
self.mask_embed.build(None)
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
"""Embeds point prompts."""
@ -812,7 +931,7 @@ class TFSamVisionAttention(tf.keras.layers.Layer):
raise ValueError("Input size must be provided if using relative positional encoding.")
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
if self.input_size is not None:
# initialize relative positional embeddings
self.rel_pos_h = self.add_weight(
@ -821,7 +940,16 @@ class TFSamVisionAttention(tf.keras.layers.Layer):
self.rel_pos_w = self.add_weight(
shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "qkv", None) is not None:
with tf.name_scope(self.qkv.name):
self.qkv.build([None, None, self.config.hidden_size])
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, self.config.hidden_size])
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
"""
@ -949,6 +1077,7 @@ class TFSamVisionLayer(tf.keras.layers.Layer):
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
self.mlp = TFSamMLPBlock(config, name="mlp")
self.window_size = window_size
self.config = config
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
batch_size, height, width, channel = shape_list(hidden_states)
@ -1016,6 +1145,23 @@ class TFSamVisionLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, None, self.config.hidden_size])
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, None, self.config.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
class TFSamVisionNeck(tf.keras.layers.Layer):
def __init__(self, config: SamVisionConfig, **kwargs):
@ -1047,6 +1193,23 @@ class TFSamVisionNeck(tf.keras.layers.Layer):
hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv1", None) is not None:
with tf.name_scope(self.conv1.name):
self.conv1.build([None, None, None, self.config.hidden_size])
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build(None)
if getattr(self, "conv2", None) is not None:
with tf.name_scope(self.conv2.name):
self.conv2.build([None, None, None, self.config.output_channels])
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build(None)
class TFSamVisionEncoder(tf.keras.layers.Layer):
def __init__(self, config: SamVisionConfig, **kwargs):
@ -1069,7 +1232,10 @@ class TFSamVisionEncoder(tf.keras.layers.Layer):
self.neck = TFSamVisionNeck(config, name="neck")
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.config.use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = self.add_weight(
@ -1083,7 +1249,16 @@ class TFSamVisionEncoder(tf.keras.layers.Layer):
trainable=True,
name="pos_embed",
)
super().build(input_shape)
if getattr(self, "patch_embed", None) is not None:
with tf.name_scope(self.patch_embed.name):
self.patch_embed.build(None)
if getattr(self, "neck", None) is not None:
with tf.name_scope(self.neck.name):
self.neck.build(None)
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
def get_input_embeddings(self):
return self.patch_embed
@ -1463,3 +1638,20 @@ class TFSamModel(TFSamPreTrainedModel):
vision_attentions=attns if self.config.output_attentions else None,
mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "shared_image_embedding", None) is not None:
with tf.name_scope(self.shared_image_embedding.name):
self.shared_image_embedding.build(None)
if getattr(self, "vision_encoder", None) is not None:
with tf.name_scope(self.vision_encoder.name):
self.vision_encoder.build(None)
if getattr(self, "prompt_encoder", None) is not None:
with tf.name_scope(self.prompt_encoder.name):
self.prompt_encoder.build(None)
if getattr(self, "mask_decoder", None) is not None:
with tf.name_scope(self.mask_decoder.name):
self.mask_decoder.build(None)

View File

@ -79,7 +79,7 @@ class TFSegformerDropPath(tf.keras.layers.Layer):
class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
"""Construct the overlapping patch embeddings."""
def __init__(self, patch_size, stride, hidden_size, **kwargs):
def __init__(self, patch_size, stride, num_channels, hidden_size, **kwargs):
super().__init__(**kwargs)
self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2)
self.proj = tf.keras.layers.Conv2D(
@ -87,6 +87,8 @@ class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
self.num_channels = num_channels
self.hidden_size = hidden_size
def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
embeddings = self.proj(self.padding(pixel_values))
@ -99,6 +101,17 @@ class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
embeddings = self.layer_norm(embeddings)
return embeddings, height, width
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, None, self.num_channels])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.hidden_size])
class TFSegformerEfficientSelfAttention(tf.keras.layers.Layer):
"""SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
@ -196,18 +209,47 @@ class TFSegformerEfficientSelfAttention(tf.keras.layers.Layer):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.hidden_size])
if getattr(self, "sr", None) is not None:
with tf.name_scope(self.sr.name):
self.sr.build([None, None, None, self.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.hidden_size])
class TFSegformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.hidden_size = hidden_size
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.hidden_size])
class TFSegformerAttention(tf.keras.layers.Layer):
def __init__(
@ -237,6 +279,17 @@ class TFSegformerAttention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFSegformerDWConv(tf.keras.layers.Layer):
def __init__(self, dim: int = 768, **kwargs):
@ -244,6 +297,7 @@ class TFSegformerDWConv(tf.keras.layers.Layer):
self.depthwise_convolution = tf.keras.layers.Conv2D(
filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
)
self.dim = dim
def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
batch_size = shape_list(hidden_states)[0]
@ -257,6 +311,14 @@ class TFSegformerDWConv(tf.keras.layers.Layer):
hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "depthwise_convolution", None) is not None:
with tf.name_scope(self.depthwise_convolution.name):
self.depthwise_convolution.build([None, None, None, self.dim])
class TFSegformerMixFFN(tf.keras.layers.Layer):
def __init__(
@ -277,6 +339,8 @@ class TFSegformerMixFFN(tf.keras.layers.Layer):
self.intermediate_act_fn = config.hidden_act
self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.hidden_features = hidden_features
self.in_features = in_features
def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
hidden_states = self.dense1(hidden_states)
@ -287,6 +351,20 @@ class TFSegformerMixFFN(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense1", None) is not None:
with tf.name_scope(self.dense1.name):
self.dense1.build([None, None, self.in_features])
if getattr(self, "depthwise_convolution", None) is not None:
with tf.name_scope(self.depthwise_convolution.name):
self.depthwise_convolution.build(None)
if getattr(self, "dense2", None) is not None:
with tf.name_scope(self.dense2.name):
self.dense2.build([None, None, self.hidden_features])
class TFSegformerLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the original implementation."""
@ -314,6 +392,7 @@ class TFSegformerLayer(tf.keras.layers.Layer):
self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
mlp_hidden_size = int(hidden_size * mlp_ratio)
self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")
self.hidden_size = hidden_size
def call(
self,
@ -347,6 +426,23 @@ class TFSegformerLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm_1", None) is not None:
with tf.name_scope(self.layer_norm_1.name):
self.layer_norm_1.build([None, None, self.hidden_size])
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm_2", None) is not None:
with tf.name_scope(self.layer_norm_2.name):
self.layer_norm_2.build([None, None, self.hidden_size])
if getattr(self, "mlp", None) is not None:
with tf.name_scope(self.mlp.name):
self.mlp.build(None)
class TFSegformerEncoder(tf.keras.layers.Layer):
def __init__(self, config: SegformerConfig, **kwargs):
@ -363,6 +459,7 @@ class TFSegformerEncoder(tf.keras.layers.Layer):
TFSegformerOverlapPatchEmbeddings(
patch_size=config.patch_sizes[i],
stride=config.strides[i],
num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
hidden_size=config.hidden_sizes[i],
name=f"patch_embeddings.{i}",
)
@ -449,6 +546,24 @@ class TFSegformerEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norms", None) is not None:
for layer, shape in zip(self.layer_norms, self.config.hidden_sizes):
with tf.name_scope(layer.name):
layer.build([None, None, shape])
if getattr(self, "block", None) is not None:
for block in self.block:
for layer in block:
with tf.name_scope(layer.name):
layer.build(None)
if getattr(self, "embeddings", None) is not None:
for layer in self.embeddings:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFSegformerMainLayer(tf.keras.layers.Layer):
@ -509,6 +624,14 @@ class TFSegformerMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
class TFSegformerPreTrainedModel(TFPreTrainedModel):
"""
@ -605,6 +728,14 @@ class TFSegformerModel(TFSegformerPreTrainedModel):
)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "segformer", None) is not None:
with tf.name_scope(self.segformer.name):
self.segformer.build(None)
@add_start_docstrings(
"""
@ -622,6 +753,7 @@ class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceCl
# Classifier head
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -668,15 +800,27 @@ class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceCl
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "segformer", None) is not None:
with tf.name_scope(self.segformer.name):
self.segformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_sizes[-1]])
class TFSegformerMLP(tf.keras.layers.Layer):
"""
Linear Embedding.
"""
def __init__(self, config: SegformerConfig, **kwargs):
def __init__(self, input_dim: int, config: SegformerConfig, **kwargs):
super().__init__(**kwargs)
self.proj = tf.keras.layers.Dense(config.decoder_hidden_size, name="proj")
self.input_dim = input_dim
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
height = shape_list(hidden_states)[1]
@ -686,6 +830,14 @@ class TFSegformerMLP(tf.keras.layers.Layer):
hidden_states = self.proj(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build([None, None, self.input_dim])
class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
def __init__(self, config: SegformerConfig, **kwargs):
@ -693,7 +845,7 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
mlps = []
for i in range(config.num_encoder_blocks):
mlp = TFSegformerMLP(config, name=f"linear_c.{i}")
mlp = TFSegformerMLP(config=config, input_dim=config.hidden_sizes[i], name=f"linear_c.{i}")
mlps.append(mlp)
self.mlps = mlps
@ -741,6 +893,26 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
return logits
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "linear_fuse", None) is not None:
with tf.name_scope(self.linear_fuse.name):
self.linear_fuse.build(
[None, None, None, self.config.decoder_hidden_size * self.config.num_encoder_blocks]
)
if getattr(self, "batch_norm", None) is not None:
with tf.name_scope(self.batch_norm.name):
self.batch_norm.build([None, None, None, self.config.decoder_hidden_size])
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, None, self.config.decoder_hidden_size])
if getattr(self, "mlps", None) is not None:
for layer in self.mlps:
with tf.name_scope(layer.name):
layer.build(None)
@add_start_docstrings(
"""SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
@ -851,3 +1023,14 @@ class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "segformer", None) is not None:
with tf.name_scope(self.segformer.name):
self.segformer.build(None)
if getattr(self, "decode_head", None) is not None:
with tf.name_scope(self.decode_head.name):
self.decode_head.build(None)

View File

@ -166,6 +166,15 @@ class TFConv1dSubsampler(tf.keras.layers.Layer):
hidden_states = glu(hidden_states, axis=2) # GLU over the Channel dimension
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv_layers", None) is not None:
for i, layer in enumerate(self.conv_layers):
with tf.name_scope(layer.name):
layer.build([None, None, self.in_channels] if i == 0 else [None, None, self.mid_channels // 2])
class TFSpeech2TextSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
"""This module produces sinusoidal positional embeddings of any length."""
@ -379,6 +388,23 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: Speech2TextConfig, **kwargs):
@ -394,6 +420,7 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False
@ -434,6 +461,26 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFSpeech2TextDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: Speech2TextConfig, **kwargs):
@ -463,6 +510,7 @@ class TFSpeech2TextDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -546,6 +594,32 @@ class TFSpeech2TextDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFSpeech2TextPreTrainedModel(TFPreTrainedModel):
config_class = Speech2TextConfig
@ -870,6 +944,24 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build(None)
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFSpeech2TextDecoder(tf.keras.layers.Layer):
@ -1092,6 +1184,24 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_tokens", None) is not None:
with tf.name_scope(self.embed_tokens.name):
self.embed_tokens.build(None)
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
@ -1197,6 +1307,17 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare Speech2Text Model outputting raw hidden-states without any specific head on top.",
@ -1279,6 +1400,14 @@ class TFSpeech2TextModel(TFSpeech2TextPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
@add_start_docstrings(
"The Speech2Text Model with a language modeling head. Can be used for summarization.",
@ -1291,6 +1420,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
self.lm_head = tf.keras.layers.Dense(self.config.vocab_size, use_bias=False, name="lm_head")
# TODO (Joao): investigate why Speech2Text has numerical issues in XLA generate
self.supports_xla_generation = False
self.config = config
def get_encoder(self):
return self.model.encoder
@ -1461,6 +1591,17 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.config.d_model])
def tf_to_pt_weight_rename(self, tf_weight):
if tf_weight == "lm_head.weight":
return tf_weight, "model.decoder.embed_tokens.weight"

View File

@ -283,6 +283,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
self.norm = tf.keras.layers.LayerNormalization(name="norm", epsilon=1e-5)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
self.config = config
def build(self, input_shape: tf.TensorShape) -> None:
if self.use_mask_token:
@ -296,7 +297,19 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
)
else:
self.position_embeddings = None
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build([None, None, self.config.embed_dim])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
def call(
self, pixel_values: tf.Tensor, bool_masked_pos: bool = None, training: bool = False
@ -381,6 +394,14 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
embeddings = tf.transpose(embeddings, (0, 2, 1))
return embeddings, output_dimensions
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFSwinPatchMerging(tf.keras.layers.Layer):
"""
@ -443,6 +464,17 @@ class TFSwinPatchMerging(tf.keras.layers.Layer):
return input_feature
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "reduction", None) is not None:
with tf.name_scope(self.reduction.name):
self.reduction.build([None, None, 4 * self.dim])
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build([None, None, 4 * self.dim])
class TFSwinDropPath(tf.keras.layers.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
@ -521,7 +553,19 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
relative_coords = tf.stack([stack_0, stack_1], axis=2)
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.all_head_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.all_head_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.all_head_size])
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
@ -597,12 +641,24 @@ class TFSwinSelfOutput(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(dim, name="dense")
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout")
self.dim = dim
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.dim])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
class TFSwinAttention(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
@ -631,6 +687,17 @@ class TFSwinAttention(tf.keras.layers.Layer):
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self", None) is not None:
with tf.name_scope(self.self.name):
self.self.build(None)
if getattr(self, "self_output", None) is not None:
with tf.name_scope(self.self_output.name):
self.self_output.build(None)
class TFSwinIntermediate(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
@ -640,24 +707,43 @@ class TFSwinIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.dim = dim
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.dim])
class TFSwinOutput(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(dim, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, "dropout")
self.config = config
self.dim = dim
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, int(self.config.mlp_ratio * self.dim)])
class TFSwinLayer(tf.keras.layers.Layer):
def __init__(
@ -684,6 +770,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
)
self.intermediate = TFSwinIntermediate(config, dim, name="intermediate")
self.swin_output = TFSwinOutput(config, dim, name="output")
self.dim = dim
def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> tf.Tensor | None:
img_mask = tf.zeros((height, width))
@ -789,6 +876,29 @@ class TFSwinLayer(tf.keras.layers.Layer):
layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
return layer_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.dim])
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.dim])
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "swin_output", None) is not None:
with tf.name_scope(self.swin_output.name):
self.swin_output.build(None)
class TFSwinStage(tf.keras.layers.Layer):
def __init__(
@ -861,6 +971,18 @@ class TFSwinStage(tf.keras.layers.Layer):
stage_outputs += layer_outputs[1:]
return stage_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "downsample", None) is not None:
with tf.name_scope(self.downsample.name):
self.downsample.build(None)
if getattr(self, "blocks", None) is not None:
for layer in self.blocks:
with tf.name_scope(layer.name):
layer.build(None)
class TFSwinEncoder(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, grid_size: Tuple[int, int], **kwargs):
@ -941,6 +1063,15 @@ class TFSwinEncoder(tf.keras.layers.Layer):
reshaped_hidden_states=all_reshaped_hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFSwinPreTrainedModel(TFPreTrainedModel):
"""
@ -1160,6 +1291,20 @@ class TFSwinMainLayer(tf.keras.layers.Layer):
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.num_features])
@add_start_docstrings(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
@ -1217,6 +1362,14 @@ class TFSwinModel(TFSwinPreTrainedModel):
return swin_outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
class TFSwinPixelShuffle(tf.keras.layers.Layer):
"""TF layer implementation of torch.nn.PixelShuffle"""
@ -1251,6 +1404,7 @@ class TFSwinDecoder(tf.keras.layers.Layer):
filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
)
self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1")
self.config = config
def call(self, x: tf.Tensor) -> tf.Tensor:
hidden_states = x
@ -1262,6 +1416,17 @@ class TFSwinDecoder(tf.keras.layers.Layer):
hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2))
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv2d", None) is not None:
with tf.name_scope(self.conv2d.name):
self.conv2d.build([None, None, None, self.config.hidden_size])
if getattr(self, "pixel_shuffle", None) is not None:
with tf.name_scope(self.pixel_shuffle.name):
self.pixel_shuffle.build(None)
@add_start_docstrings(
"Swin Model with a decoder on top for masked image modeling, as proposed in"
@ -1372,6 +1537,17 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"""
@ -1446,3 +1622,15 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "swin", None) is not None:
with tf.name_scope(self.swin.name):
self.swin.build(None)
if getattr(self, "classifier", None) is not None:
if hasattr(self.classifier, "name"):
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.swin.num_features])

View File

@ -45,7 +45,6 @@ from ...modeling_tf_utils import (
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
ContextManagers,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@ -75,16 +74,17 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFT5LayerNorm(tf.keras.layers.Layer):
def __init__(self, epsilon=1e-6, **kwargs):
def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
"""
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
"""
super().__init__(**kwargs)
self.variance_epsilon = epsilon
self.hidden_size = hidden_size
def build(self, input_shape):
"""Build shared word embedding layer"""
self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
super().build(input_shape)
def call(self, hidden_states):
@ -110,6 +110,7 @@ class TFT5DenseActDense(tf.keras.layers.Layer):
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation(config.dense_act_fn)
self.config = config
def call(self, hidden_states, training=False):
hidden_states = self.wi(hidden_states)
@ -118,6 +119,17 @@ class TFT5DenseActDense(tf.keras.layers.Layer):
hidden_states = self.wo(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wi", None) is not None:
with tf.name_scope(self.wi.name):
self.wi.build([None, None, self.config.d_model])
if getattr(self, "wo", None) is not None:
with tf.name_scope(self.wo.name):
self.wo.build([None, None, self.config.d_ff])
class TFT5DenseGatedActDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -139,6 +151,7 @@ class TFT5DenseGatedActDense(tf.keras.layers.Layer):
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation(config.dense_act_fn)
self.config = config
def call(self, hidden_states, training=False):
hidden_gelu = self.act(self.wi_0(hidden_states))
@ -148,6 +161,20 @@ class TFT5DenseGatedActDense(tf.keras.layers.Layer):
hidden_states = self.wo(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wi_0", None) is not None:
with tf.name_scope(self.wi_0.name):
self.wi_0.build([None, None, self.config.d_model])
if getattr(self, "wi_1", None) is not None:
with tf.name_scope(self.wi_1.name):
self.wi_1.build([None, None, self.config.d_model])
if getattr(self, "wo", None) is not None:
with tf.name_scope(self.wo.name):
self.wo.build([None, None, self.config.d_ff])
class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -157,7 +184,7 @@ class TFT5LayerFF(tf.keras.layers.Layer):
else:
self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, training=False):
@ -166,6 +193,17 @@ class TFT5LayerFF(tf.keras.layers.Layer):
hidden_states = hidden_states + self.dropout(dense_output, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build(None)
if getattr(self, "DenseReluDense", None) is not None:
with tf.name_scope(self.DenseReluDense.name):
self.DenseReluDense.build(None)
class TFT5Attention(tf.keras.layers.Layer):
NEW_ID = itertools.count()
@ -218,7 +256,10 @@ class TFT5Attention(tf.keras.layers.Layer):
self.pruned_heads = set()
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
if self.has_relative_attention_bias:
with tf.name_scope("relative_attention_bias"):
self.relative_attention_bias = self.add_weight(
@ -226,8 +267,18 @@ class TFT5Attention(tf.keras.layers.Layer):
shape=[self.relative_attention_num_buckets, self.n_heads],
initializer=self.relative_attention_bias_initializer, # Add initializer
)
return super().build(input_shape)
if getattr(self, "q", None) is not None:
with tf.name_scope(self.q.name):
self.q.build([None, None, self.d_model])
if getattr(self, "k", None) is not None:
with tf.name_scope(self.k.name):
self.k.build([None, None, self.d_model])
if getattr(self, "v", None) is not None:
with tf.name_scope(self.v.name):
self.v.build([None, None, self.d_model])
if getattr(self, "o", None) is not None:
with tf.name_scope(self.o.name):
self.o.build([None, None, self.inner_dim])
def prune_heads(self, heads):
raise NotImplementedError
@ -439,7 +490,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
has_relative_attention_bias=has_relative_attention_bias,
name="SelfAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(
@ -468,6 +519,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "SelfAttention", None) is not None:
with tf.name_scope(self.SelfAttention.name):
self.SelfAttention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build(None)
class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -477,7 +539,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
has_relative_attention_bias=False,
name="EncDecAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(
@ -510,6 +572,17 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "EncDecAttention", None) is not None:
with tf.name_scope(self.EncDecAttention.name):
self.EncDecAttention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build(None)
class TFT5Block(tf.keras.layers.Layer):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
@ -613,6 +686,15 @@ class TFT5Block(tf.keras.layers.Layer):
outputs = outputs + (present_key_value_state,) + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
def build(self, input_shape=None):
if self.built:
return
self.built = True
for layer_module in self.layer:
if hasattr(layer_module, "name"):
with tf.name_scope(layer_module.name):
layer_module.build(None)
####################################################
# The full model without a specific pretrained or finetuning head is
@ -640,7 +722,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
for i in range(config.num_layers)
]
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
self.final_layer_norm = TFT5LayerNorm(
config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm"
)
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def _prune_heads(self, heads_to_prune):
@ -679,16 +763,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
@ -846,6 +922,18 @@ class TFT5MainLayer(tf.keras.layers.Layer):
attentions=all_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build(None)
if getattr(self, "block", None) is not None:
for layer in self.block:
with tf.name_scope(layer.name):
layer.build(None)
####################################################
# TFT5PreTrainedModel is a sub-class of tf.keras.Model
@ -1221,6 +1309,22 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
@ -1250,6 +1354,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
self.lm_head = tf.keras.layers.Dense(
config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
) # Update init weights as in flax
self.config = config
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
@ -1471,6 +1576,25 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.config.d_model])
@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
@ -1549,3 +1673,16 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
self.shared.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)

View File

@ -160,7 +160,7 @@ class TFTapasEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -186,7 +186,12 @@ class TFTapasEmbeddings(tf.keras.layers.Layer):
),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def call(
self,
@ -279,6 +284,7 @@ class TFTapasSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -368,6 +374,20 @@ class TFTapasSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Tapas
class TFTapasSelfOutput(tf.keras.layers.Layer):
@ -379,6 +399,7 @@ class TFTapasSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -387,6 +408,17 @@ class TFTapasSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Tapas
class TFTapasAttention(tf.keras.layers.Layer):
@ -428,6 +460,17 @@ class TFTapasAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Tapas
class TFTapasIntermediate(tf.keras.layers.Layer):
@ -442,6 +485,7 @@ class TFTapasIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -449,6 +493,14 @@ class TFTapasIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Tapas
class TFTapasOutput(tf.keras.layers.Layer):
@ -460,6 +512,7 @@ class TFTapasOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -468,6 +521,17 @@ class TFTapasOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Tapas
class TFTapasLayer(tf.keras.layers.Layer):
@ -555,6 +619,23 @@ class TFTapasLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Tapas
class TFTapasEncoder(tf.keras.layers.Layer):
@ -625,6 +706,15 @@ class TFTapasEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Tapas
class TFTapasPooler(tf.keras.layers.Layer):
@ -637,6 +727,7 @@ class TFTapasPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -646,6 +737,14 @@ class TFTapasPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->Tapas
class TFTapasPredictionHeadTransform(tf.keras.layers.Layer):
@ -664,6 +763,7 @@ class TFTapasPredictionHeadTransform(tf.keras.layers.Layer):
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -672,6 +772,17 @@ class TFTapasPredictionHeadTransform(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead with Bert->Tapas
class TFTapasLMPredictionHead(tf.keras.layers.Layer):
@ -687,10 +798,15 @@ class TFTapasLMPredictionHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "transform", None) is not None:
with tf.name_scope(self.transform.name):
self.transform.build(None)
def get_output_embeddings(self) -> tf.keras.layers.Layer:
return self.input_embeddings
@ -729,6 +845,14 @@ class TFTapasMLMHead(tf.keras.layers.Layer):
return prediction_scores
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "predictions", None) is not None:
with tf.name_scope(self.predictions.name):
self.predictions.build(None)
@keras_serializable
class TFTapasMainLayer(tf.keras.layers.Layer):
@ -852,6 +976,20 @@ class TFTapasMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFTapasPreTrainedModel(TFPreTrainedModel):
"""
@ -1033,6 +1171,14 @@ class TFTapasModel(TFTapasPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "tapas", None) is not None:
with tf.name_scope(self.tapas.name):
self.tapas.build(None)
@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss):
@ -1129,6 +1275,17 @@ class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "tapas", None) is not None:
with tf.name_scope(self.tapas.name):
self.tapas.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
class TFTapasComputeTokenLogits(tf.keras.layers.Layer):
def __init__(self, config: TapasConfig, **kwargs):
@ -1552,6 +1709,23 @@ class TFTapasForQuestionAnswering(TFTapasPreTrainedModel):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "tapas", None) is not None:
with tf.name_scope(self.tapas.name):
self.tapas.build(None)
if getattr(self, "compute_token_logits", None) is not None:
with tf.name_scope(self.compute_token_logits.name):
self.compute_token_logits.build(None)
if getattr(self, "compute_column_logits", None) is not None:
with tf.name_scope(self.compute_column_logits.name):
self.compute_column_logits.build(None)
if getattr(self, "aggregation_classifier", None) is not None:
with tf.name_scope(self.aggregation_classifier.name):
self.aggregation_classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1570,6 +1744,7 @@ class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassif
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1654,6 +1829,20 @@ class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassif
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "tapas", None) is not None:
with tf.name_scope(self.tapas.name):
self.tapas.build(None)
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
""" TAPAS utilities."""

View File

@ -684,3 +684,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported. "
"Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "enc_to_dec_proj", None) is not None:
with tf.name_scope(self.enc_to_dec_proj.name):
self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)

View File

@ -220,12 +220,26 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
self.visual_projection = Dense(self.projection_dim, use_bias=False, name="visual_projection")
self.text_projection = Dense(self.projection_dim, use_bias=False, name="text_projection")
self.logit_scale = None
self.config = config
def build(self, input_shape=None):
if self.built:
return
self.built = True
# Build in the build() method to make sure the names are right
initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value)
self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
super().build(input_shape)
if getattr(self, "visual_projection", None) is not None:
with tf.name_scope(self.visual_projection.name):
self.visual_projection.build([None, None, self.vision_embed_dim])
if getattr(self, "text_projection", None) is not None:
with tf.name_scope(self.text_projection.name):
self.text_projection.build([None, None, self.text_embed_dim])
with tf.name_scope(self.vision_model.name):
self.vision_model.build(None)
with tf.name_scope(self.text_model.name):
self.text_model.build(None)
def tf_to_pt_weight_rename(self, tf_weight):
# Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models

View File

@ -66,7 +66,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
num_patches = self.patch_embeddings.num_patches
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
@ -81,7 +81,12 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
name="position_embeddings",
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
@ -205,6 +210,14 @@ class TFViTPatchEmbeddings(tf.keras.layers.Layer):
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
class TFViTSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -231,6 +244,7 @@ class TFViTSelfAttention(tf.keras.layers.Layer):
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -280,6 +294,20 @@ class TFViTSelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
class TFViTSelfOutput(tf.keras.layers.Layer):
"""
@ -294,6 +322,7 @@ class TFViTSelfOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -301,6 +330,14 @@ class TFViTSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFViTAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -329,6 +366,17 @@ class TFViTAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
class TFViTIntermediate(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -342,6 +390,7 @@ class TFViTIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -349,6 +398,14 @@ class TFViTIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
class TFViTOutput(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -358,6 +415,7 @@ class TFViTOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -366,6 +424,14 @@ class TFViTOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
class TFViTLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
@ -383,6 +449,7 @@ class TFViTLayer(tf.keras.layers.Layer):
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
self.config = config
def call(
self,
@ -416,6 +483,26 @@ class TFViTLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "vit_output", None) is not None:
with tf.name_scope(self.vit_output.name):
self.vit_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
class TFViTEncoder(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -461,6 +548,15 @@ class TFViTEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFViTMainLayer(tf.keras.layers.Layer):
@ -539,6 +635,23 @@ class TFViTMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
class TFViTPreTrainedModel(TFPreTrainedModel):
"""
@ -665,6 +778,14 @@ class TFViTModel(TFViTPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
class TFViTPooler(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
@ -676,6 +797,7 @@ class TFViTPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -685,6 +807,14 @@ class TFViTPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -714,6 +844,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@ -764,3 +895,14 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])

View File

@ -213,7 +213,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
self.config = config
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
@ -233,7 +233,12 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
)[None, ...]
self.position_embeddings.assign(pos_embed)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "patch_embeddings", None) is not None:
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
"""
@ -352,6 +357,14 @@ class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer):
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, None, self.num_channels])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE
class TFViTMAESelfAttention(tf.keras.layers.Layer):
@ -379,6 +392,7 @@ class TFViTMAESelfAttention(tf.keras.layers.Layer):
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -428,6 +442,20 @@ class TFViTMAESelfAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
class TFViTMAESelfOutput(tf.keras.layers.Layer):
@ -443,6 +471,7 @@ class TFViTMAESelfOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -450,6 +479,14 @@ class TFViTMAESelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
class TFViTMAEAttention(tf.keras.layers.Layer):
@ -479,6 +516,17 @@ class TFViTMAEAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
class TFViTMAEIntermediate(tf.keras.layers.Layer):
@ -493,6 +541,7 @@ class TFViTMAEIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -500,6 +549,14 @@ class TFViTMAEIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
class TFViTMAEOutput(tf.keras.layers.Layer):
@ -510,6 +567,7 @@ class TFViTMAEOutput(tf.keras.layers.Layer):
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -518,6 +576,14 @@ class TFViTMAEOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
class TFViTMAELayer(tf.keras.layers.Layer):
@ -536,6 +602,7 @@ class TFViTMAELayer(tf.keras.layers.Layer):
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
self.config = config
def call(
self,
@ -569,6 +636,26 @@ class TFViTMAELayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "vit_output", None) is not None:
with tf.name_scope(self.vit_output.name):
self.vit_output.build(None)
if getattr(self, "layernorm_before", None) is not None:
with tf.name_scope(self.layernorm_before.name):
self.layernorm_before.build([None, None, self.config.hidden_size])
if getattr(self, "layernorm_after", None) is not None:
with tf.name_scope(self.layernorm_after.name):
self.layernorm_after.build([None, None, self.config.hidden_size])
# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE
class TFViTMAEEncoder(tf.keras.layers.Layer):
@ -615,6 +702,15 @@ class TFViTMAEEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFViTMAEMainLayer(tf.keras.layers.Layer):
@ -687,6 +783,20 @@ class TFViTMAEMainLayer(tf.keras.layers.Layer):
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "layernorm", None) is not None:
with tf.name_scope(self.layernorm.name):
self.layernorm.build([None, None, self.config.hidden_size])
class TFViTMAEPreTrainedModel(TFPreTrainedModel):
"""
@ -829,6 +939,14 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
class TFViTMAEDecoder(tf.keras.layers.Layer):
def __init__(self, config, num_patches, **kwargs):
@ -853,7 +971,7 @@ class TFViTMAEDecoder(tf.keras.layers.Layer):
self.config = config
self.num_patches = num_patches
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.mask_token = self.add_weight(
shape=(1, 1, self.config.decoder_hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
@ -873,7 +991,22 @@ class TFViTMAEDecoder(tf.keras.layers.Layer):
)[None, ...]
self.decoder_pos_embed.assign(decoder_pos_embed)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "decoder_embed", None) is not None:
with tf.name_scope(self.decoder_embed.name):
self.decoder_embed.build([None, None, self.config.hidden_size])
if getattr(self, "decoder_norm", None) is not None:
with tf.name_scope(self.decoder_norm.name):
self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
if getattr(self, "decoder_pred", None) is not None:
with tf.name_scope(self.decoder_pred.name):
self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
if getattr(self, "decoder_layers", None) is not None:
for layer in self.decoder_layers:
with tf.name_scope(layer.name):
layer.build(None)
def call(
self,
@ -1128,3 +1261,14 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "vit", None) is not None:
with tf.name_scope(self.vit.name):
self.vit.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)

View File

@ -450,11 +450,6 @@ class TFWav2Vec2WeightNormConv1D(tf.keras.layers.Conv1D):
def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# If a specific input shape is passed in, we need to modify it to account for padding
# Not necessary if those portions of the shape are None
if input_shape[-2] is not None:
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
@ -502,6 +497,14 @@ class TFWav2Vec2NoLayerNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
class TFWav2Vec2LayerNormConvLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
@ -525,6 +528,17 @@ class TFWav2Vec2LayerNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
class TFWav2Vec2GroupNormConvLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
@ -550,6 +564,17 @@ class TFWav2Vec2GroupNormConvLayer(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.in_conv_dim])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.out_conv_dim])
class TFWav2Vec2PositionalConvEmbedding(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
@ -563,6 +588,7 @@ class TFWav2Vec2PositionalConvEmbedding(tf.keras.layers.Layer):
)
self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = get_tf_activation(config.feat_extract_activation)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
@ -570,6 +596,14 @@ class TFWav2Vec2PositionalConvEmbedding(tf.keras.layers.Layer):
hidden_states = self.activation(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv", None) is not None:
with tf.name_scope(self.conv.name):
self.conv.build([None, None, self.config.hidden_size])
class TFWav2Vec2SamePadLayer(tf.keras.layers.Layer):
def __init__(self, num_conv_pos_embeddings, **kwargs):
@ -608,6 +642,15 @@ class TFWav2Vec2FeatureEncoder(tf.keras.layers.Layer):
hidden_states = conv_layer(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv_layers", None) is not None:
for conv_layer in self.conv_layers:
with tf.name_scope(conv_layer.name):
conv_layer.build(None)
class TFWav2Vec2FeatureExtractor(TFWav2Vec2FeatureEncoder):
def __init__(self, config, **kwargs):
@ -632,6 +675,7 @@ class TFWav2Vec2FeatureProjection(tf.keras.layers.Layer):
name="projection",
)
self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
norm_hidden_states = self.layer_norm(hidden_states)
@ -639,6 +683,17 @@ class TFWav2Vec2FeatureProjection(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states, norm_hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.conv_dim[-1]])
if getattr(self, "projection", None) is not None:
with tf.name_scope(self.projection.name):
self.projection.build([None, None, self.config.conv_dim[-1]])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
class TFWav2Vec2Attention(tf.keras.layers.Layer):
@ -793,6 +848,23 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFWav2Vec2FeedForward(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
@ -815,6 +887,7 @@ class TFWav2Vec2FeedForward(tf.keras.layers.Layer):
name="output_dense",
)
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.config = config
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.intermediate_dense(hidden_states)
@ -825,6 +898,17 @@ class TFWav2Vec2FeedForward(tf.keras.layers.Layer):
hidden_states = self.output_dropout(hidden_states, training=training)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "intermediate_dense", None) is not None:
with tf.name_scope(self.intermediate_dense.name):
self.intermediate_dense.build([None, None, self.config.hidden_size])
if getattr(self, "output_dense", None) is not None:
with tf.name_scope(self.output_dense.name):
self.output_dense.build([None, None, self.config.intermediate_size])
class TFWav2Vec2EncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
@ -842,6 +926,7 @@ class TFWav2Vec2EncoderLayer(tf.keras.layers.Layer):
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
self.config = config
def call(
self,
@ -868,6 +953,23 @@ class TFWav2Vec2EncoderLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
class TFWav2Vec2EncoderLayerStableLayerNorm(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
@ -885,6 +987,7 @@ class TFWav2Vec2EncoderLayerStableLayerNorm(tf.keras.layers.Layer):
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
self.config = config
def call(
self,
@ -909,6 +1012,23 @@ class TFWav2Vec2EncoderLayerStableLayerNorm(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "feed_forward", None) is not None:
with tf.name_scope(self.feed_forward.name):
self.feed_forward.build(None)
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.config.hidden_size])
class TFWav2Vec2Encoder(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
@ -974,6 +1094,21 @@ class TFWav2Vec2Encoder(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
class TFWav2Vec2EncoderStableLayerNorm(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
@ -1041,6 +1176,21 @@ class TFWav2Vec2EncoderStableLayerNorm(tf.keras.layers.Layer):
attentions=all_self_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "pos_conv_embed", None) is not None:
with tf.name_scope(self.pos_conv_embed.name):
self.pos_conv_embed.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
@ -1057,12 +1207,23 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
else:
self.encoder = TFWav2Vec2Encoder(config, name="encoder")
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
self.masked_spec_embed = self.add_weight(
shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "feature_extractor", None) is not None:
with tf.name_scope(self.feature_extractor.name):
self.feature_extractor.build(None)
if getattr(self, "feature_projection", None) is not None:
with tf.name_scope(self.feature_projection.name):
self.feature_projection.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
"""
@ -1419,6 +1580,14 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wav2vec2", None) is not None:
with tf.name_scope(self.wav2vec2.name):
self.wav2vec2.build(None)
@add_start_docstrings(
"""TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
@ -1431,6 +1600,9 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.dropout = tf.keras.layers.Dropout(config.final_dropout)
self.lm_head = tf.keras.layers.Dense(config.vocab_size, name="lm_head")
self.output_hidden_size = (
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
)
def freeze_feature_extractor(self):
"""
@ -1572,6 +1744,17 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wav2vec2", None) is not None:
with tf.name_scope(self.wav2vec2.name):
self.wav2vec2.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.output_hidden_size])
class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
def __init__(self, config):
@ -1669,3 +1852,17 @@ class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "wav2vec2", None) is not None:
with tf.name_scope(self.wav2vec2.name):
self.wav2vec2.build(None)
if getattr(self, "projector", None) is not None:
with tf.name_scope(self.projector.name):
self.projector.build([None, None, self.config.hidden_size])
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.classifier_proj_size])

View File

@ -313,6 +313,23 @@ class TFWhisperAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextEncoderLayer with Speech2Text->Whisper
class TFWhisperEncoderLayer(tf.keras.layers.Layer):
@ -329,6 +346,7 @@ class TFWhisperEncoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training: bool = False
@ -369,6 +387,26 @@ class TFWhisperEncoderLayer(tf.keras.layers.Layer):
return hidden_states, self_attn_weights
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.encoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.speech_to_text.modeling_tf_speech_to_text.TFSpeech2TextDecoderLayer with Speech2Text->Whisper
class TFWhisperDecoderLayer(tf.keras.layers.Layer):
@ -399,6 +437,7 @@ class TFWhisperDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
def call(
self,
@ -482,6 +521,32 @@ class TFWhisperDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.decoder_ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
class TFWhisperPreTrainedModel(TFPreTrainedModel):
config_class = WhisperConfig
@ -749,6 +814,27 @@ class TFWhisperEncoder(tf.keras.layers.Layer):
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "conv1", None) is not None:
with tf.name_scope(self.conv1.name):
self.conv1.build([None, None, self.num_mel_bins])
if getattr(self, "conv2", None) is not None:
with tf.name_scope(self.conv2.name):
self.conv2.build([None, None, self.embed_dim])
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "encoder_layers", None) is not None:
for layer in self.encoder_layers:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
class TFWhisperDecoder(tf.keras.layers.Layer):
@ -988,6 +1074,24 @@ class TFWhisperDecoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embed_tokens", None) is not None:
with tf.name_scope(self.embed_tokens.name):
self.embed_tokens.build(None)
if getattr(self, "embed_positions", None) is not None:
with tf.name_scope(self.embed_positions.name):
self.embed_positions.build(None)
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "decoder_layers", None) is not None:
for layer in self.decoder_layers:
with tf.name_scope(layer.name):
layer.build(None)
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
@ -1111,6 +1215,17 @@ class TFWhisperMainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "decoder", None) is not None:
with tf.name_scope(self.decoder.name):
self.decoder.build(None)
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
@ -1219,6 +1334,14 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
encoder_attentions=enc_attns,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
@add_start_docstrings(
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
@ -1630,3 +1753,11 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)

View File

@ -301,6 +301,23 @@ class TFXGLMAttention(tf.keras.layers.Layer):
return attn_output, attn_weights, past_key_value
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build([None, None, self.embed_dim])
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build([None, None, self.embed_dim])
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build([None, None, self.embed_dim])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.embed_dim])
class TFXGLMDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: XGLMConfig, **kwargs: Any) -> None:
@ -333,6 +350,7 @@ class TFXGLMDecoderLayer(tf.keras.layers.Layer):
self.fc1 = tf.keras.layers.Dense(config.ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.config = config
# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer.call
def call(
@ -415,6 +433,32 @@ class TFXGLMDecoderLayer(tf.keras.layers.Layer):
present_key_value,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attn", None) is not None:
with tf.name_scope(self.self_attn.name):
self.self_attn.build(None)
if getattr(self, "self_attn_layer_norm", None) is not None:
with tf.name_scope(self.self_attn_layer_norm.name):
self.self_attn_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build([None, None, self.embed_dim])
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build([None, None, self.config.ffn_dim])
if getattr(self, "final_layer_norm", None) is not None:
with tf.name_scope(self.final_layer_norm.name):
self.final_layer_norm.build([None, None, self.embed_dim])
if getattr(self, "encoder_attn", None) is not None:
with tf.name_scope(self.encoder_attn.name):
self.encoder_attn.build(None)
if getattr(self, "encoder_attn_layer_norm", None) is not None:
with tf.name_scope(self.encoder_attn_layer_norm.name):
self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
@keras_serializable
class TFXGLMMainLayer(tf.keras.layers.Layer):
@ -609,6 +653,21 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "embed_tokens", None) is not None:
with tf.name_scope(self.embed_tokens.name):
self.embed_tokens.build(None)
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFXGLMPreTrainedModel(TFPreTrainedModel):
config_class = XGLMConfig
@ -792,6 +851,14 @@ class TFXGLMModel(TFXGLMPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
@add_start_docstrings(
"""
@ -822,6 +889,7 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
kernel_initializer=get_initializer(config.init_std),
name="lm_head",
)
self.config = config
def get_output_embeddings(self):
return self.lm_head
@ -925,6 +993,17 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "model", None) is not None:
with tf.name_scope(self.model.name):
self.model.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build([None, None, self.config.hidden_size])
def tf_to_pt_weight_rename(self, tf_weight):
if tf_weight == "lm_head.weight":
return tf_weight, "model.embed_tokens.weight"

View File

@ -132,6 +132,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
self.pruned_heads = set()
self.dim = dim
def prune_heads(self, heads):
raise NotImplementedError
@ -206,6 +207,23 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "q_lin", None) is not None:
with tf.name_scope(self.q_lin.name):
self.q_lin.build([None, None, self.dim])
if getattr(self, "k_lin", None) is not None:
with tf.name_scope(self.k_lin.name):
self.k_lin.build([None, None, self.dim])
if getattr(self, "v_lin", None) is not None:
with tf.name_scope(self.v_lin.name):
self.v_lin.build([None, None, self.dim])
if getattr(self, "out_lin", None) is not None:
with tf.name_scope(self.out_lin.name):
self.out_lin.build([None, None, self.dim])
class TFXLMTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
@ -215,6 +233,8 @@ class TFXLMTransformerFFN(tf.keras.layers.Layer):
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.in_dim = in_dim
self.dim_hidden = dim_hidden
def call(self, input, training=False):
x = self.lin1(input)
@ -224,6 +244,17 @@ class TFXLMTransformerFFN(tf.keras.layers.Layer):
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "lin1", None) is not None:
with tf.name_scope(self.lin1.name):
self.lin1.build([None, None, self.in_dim])
if getattr(self, "lin2", None) is not None:
with tf.name_scope(self.lin2.name):
self.lin2.build([None, None, self.dim_hidden])
@keras_serializable
class TFXLMMainLayer(tf.keras.layers.Layer):
@ -316,7 +347,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
def build(self, input_shape):
def build(self, input_shape=None):
if self.built:
return
self.built = True
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
@ -331,8 +365,24 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
shape=[self.n_langs, self.dim],
initializer=get_initializer(self.embed_init_std),
)
super().build(input_shape)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "layer_norm_emb", None) is not None:
with tf.name_scope(self.layer_norm_emb.name):
self.layer_norm_emb.build([None, None, self.dim])
for layer in self.attentions:
with tf.name_scope(layer.name):
layer.build(None)
for layer in self.layer_norm1:
with tf.name_scope(layer.name):
layer.build([None, None, self.dim])
for layer in self.ffns:
with tf.name_scope(layer.name):
layer.build(None)
for layer in self.layer_norm2:
with tf.name_scope(layer.name):
layer.build([None, None, self.dim])
def get_input_embeddings(self):
return self.embeddings
@ -734,6 +784,14 @@ class TFXLMModel(TFXLMPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
class TFXLMPredLayer(tf.keras.layers.Layer):
"""
@ -871,6 +929,17 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "pred_layer", None) is not None:
with tf.name_scope(self.pred_layer.name):
self.pred_layer.build(None)
@add_start_docstrings(
"""
@ -949,6 +1018,17 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
@add_start_docstrings(
"""
@ -966,6 +1046,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
self.logits_proj = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
self.config = config
@property
def dummy_inputs(self):
@ -1068,6 +1149,20 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "logits_proj", None) is not None:
with tf.name_scope(self.logits_proj.name):
self.logits_proj.build([None, None, self.config.num_labels])
@add_start_docstrings(
"""
@ -1086,6 +1181,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1148,6 +1244,17 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1163,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.init_std), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1238,3 +1346,14 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -178,7 +178,7 @@ class TFXLMRobertaEmbeddings(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def build(self, input_shape: tf.TensorShape):
def build(self, input_shape=None):
with tf.name_scope("word_embeddings"):
self.weight = self.add_weight(
name="weight",
@ -200,7 +200,12 @@ class TFXLMRobertaEmbeddings(tf.keras.layers.Layer):
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
@ -273,6 +278,7 @@ class TFXLMRobertaPooler(tf.keras.layers.Layer):
activation="tanh",
name="dense",
)
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
@ -282,6 +288,14 @@ class TFXLMRobertaPooler(tf.keras.layers.Layer):
return pooled_output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->XLMRoberta
class TFXLMRobertaSelfAttention(tf.keras.layers.Layer):
@ -311,6 +325,7 @@ class TFXLMRobertaSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
@ -400,6 +415,20 @@ class TFXLMRobertaSelfAttention(tf.keras.layers.Layer):
outputs = outputs + (past_key_value,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "query", None) is not None:
with tf.name_scope(self.query.name):
self.query.build([None, None, self.config.hidden_size])
if getattr(self, "key", None) is not None:
with tf.name_scope(self.key.name):
self.key.build([None, None, self.config.hidden_size])
if getattr(self, "value", None) is not None:
with tf.name_scope(self.value.name):
self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->XLMRoberta
class TFXLMRobertaSelfOutput(tf.keras.layers.Layer):
@ -411,6 +440,7 @@ class TFXLMRobertaSelfOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -419,6 +449,17 @@ class TFXLMRobertaSelfOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->XLMRoberta
class TFXLMRobertaAttention(tf.keras.layers.Layer):
@ -460,6 +501,17 @@ class TFXLMRobertaAttention(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "self_attention", None) is not None:
with tf.name_scope(self.self_attention.name):
self.self_attention.build(None)
if getattr(self, "dense_output", None) is not None:
with tf.name_scope(self.dense_output.name):
self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->XLMRoberta
class TFXLMRobertaIntermediate(tf.keras.layers.Layer):
@ -474,6 +526,7 @@ class TFXLMRobertaIntermediate(tf.keras.layers.Layer):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.config = config
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -481,6 +534,14 @@ class TFXLMRobertaIntermediate(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->XLMRoberta
class TFXLMRobertaOutput(tf.keras.layers.Layer):
@ -492,6 +553,7 @@ class TFXLMRobertaOutput(tf.keras.layers.Layer):
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
@ -500,6 +562,17 @@ class TFXLMRobertaOutput(tf.keras.layers.Layer):
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.intermediate_size])
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
self.LayerNorm.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->XLMRoberta
class TFXLMRobertaLayer(tf.keras.layers.Layer):
@ -587,6 +660,23 @@ class TFXLMRobertaLayer(tf.keras.layers.Layer):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "attention", None) is not None:
with tf.name_scope(self.attention.name):
self.attention.build(None)
if getattr(self, "intermediate", None) is not None:
with tf.name_scope(self.intermediate.name):
self.intermediate.build(None)
if getattr(self, "bert_output", None) is not None:
with tf.name_scope(self.bert_output.name):
self.bert_output.build(None)
if getattr(self, "crossattention", None) is not None:
with tf.name_scope(self.crossattention.name):
self.crossattention.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->XLMRoberta
class TFXLMRobertaEncoder(tf.keras.layers.Layer):
@ -657,6 +747,15 @@ class TFXLMRobertaEncoder(tf.keras.layers.Layer):
cross_attentions=all_cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
@keras_serializable
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->XLMRoberta
@ -855,6 +954,20 @@ class TFXLMRobertaMainLayer(tf.keras.layers.Layer):
cross_attentions=encoder_outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "pooler", None) is not None:
with tf.name_scope(self.pooler.name):
self.pooler.build(None)
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->XLMRoberta
class TFXLMRobertaPreTrainedModel(TFPreTrainedModel):
@ -940,6 +1053,14 @@ class TFXLMRobertaModel(TFXLMRobertaPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->XLMRoberta
class TFXLMRobertaLMHead(tf.keras.layers.Layer):
@ -960,10 +1081,18 @@ class TFXLMRobertaLMHead(tf.keras.layers.Layer):
# an output-only bias for each token.
self.decoder = input_embeddings
def build(self, input_shape):
def build(self, input_shape=None):
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.hidden_size])
def get_output_embeddings(self):
return self.decoder
@ -1072,6 +1201,17 @@ class TFXLMRobertaForMaskedLM(TFXLMRobertaPreTrainedModel, TFMaskedLanguageModel
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
@add_start_docstrings(
"XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.",
@ -1199,6 +1339,17 @@ class TFXLMRobertaForCausalLM(TFXLMRobertaPreTrainedModel, TFCausalLanguageModel
cross_attentions=outputs.cross_attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->XLMRoberta
class TFXLMRobertaClassificationHead(tf.keras.layers.Layer):
@ -1219,6 +1370,7 @@ class TFXLMRobertaClassificationHead(tf.keras.layers.Layer):
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
self.config = config
def call(self, features, training=False):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
@ -1228,6 +1380,17 @@ class TFXLMRobertaClassificationHead(tf.keras.layers.Layer):
x = self.out_proj(x)
return x
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "dense", None) is not None:
with tf.name_scope(self.dense.name):
self.dense.build([None, None, self.config.hidden_size])
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1305,6 +1468,17 @@ class TFXLMRobertaForSequenceClassification(TFXLMRobertaPreTrainedModel, TFSeque
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
@add_start_docstrings(
"""
@ -1327,6 +1501,7 @@ class TFXLMRobertaForMultipleChoice(TFXLMRobertaPreTrainedModel, TFMultipleChoic
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(
@ -1398,6 +1573,17 @@ class TFXLMRobertaForMultipleChoice(TFXLMRobertaPreTrainedModel, TFMultipleChoic
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1424,6 +1610,7 @@ class TFXLMRobertaForTokenClassification(TFXLMRobertaPreTrainedModel, TFTokenCla
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1482,6 +1669,17 @@ class TFXLMRobertaForTokenClassification(TFXLMRobertaPreTrainedModel, TFTokenCla
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1503,6 +1701,7 @@ class TFXLMRobertaForQuestionAnswering(TFXLMRobertaPreTrainedModel, TFQuestionAn
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1574,3 +1773,14 @@ class TFXLMRobertaForQuestionAnswering(TFXLMRobertaPreTrainedModel, TFQuestionAn
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "roberta", None) is not None:
with tf.name_scope(self.roberta.name):
self.roberta.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -85,8 +85,9 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.config = config
def build(self, input_shape):
def build(self, input_shape=None):
initializer = get_initializer(self.initializer_range)
self.q = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
@ -115,7 +116,13 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
self.seg_embed = self.add_weight(
shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
def prune_heads(self, heads):
raise NotImplementedError
@ -344,6 +351,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
self.activation_function = get_tf_activation(config.ff_activation)
else:
self.activation_function = config.ff_activation
self.config = config
def call(self, inp, training=False):
output = inp
@ -355,6 +363,20 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
output = self.layer_norm(output + inp)
return output
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm", None) is not None:
with tf.name_scope(self.layer_norm.name):
self.layer_norm.build([None, None, self.config.d_model])
if getattr(self, "layer_1", None) is not None:
with tf.name_scope(self.layer_1.name):
self.layer_1.build([None, None, self.config.d_model])
if getattr(self, "layer_2", None) is not None:
with tf.name_scope(self.layer_2.name):
self.layer_2.build([None, None, self.config.d_inner])
class TFXLNetLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
@ -399,6 +421,17 @@ class TFXLNetLayer(tf.keras.layers.Layer):
outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "rel_attn", None) is not None:
with tf.name_scope(self.rel_attn.name):
self.rel_attn.build(None)
if getattr(self, "ff", None) is not None:
with tf.name_scope(self.ff.name):
self.ff.build(None)
class TFXLNetLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
@ -471,12 +504,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.word_embedding.weight = value
self.word_embedding.vocab_size = shape_list(value)[0]
def build(self, input_shape):
def build(self, input_shape=None):
initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(
shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
)
super().build(input_shape)
if self.built:
return
self.built = True
if getattr(self, "word_embedding", None) is not None:
with tf.name_scope(self.word_embedding.name):
self.word_embedding.build(None)
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
layer.build(None)
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
@ -1177,6 +1220,14 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
@add_start_docstrings(
"""
@ -1336,6 +1387,17 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "lm_loss", None) is not None:
with tf.name_scope(self.lm_loss.name):
self.lm_loss.build(None)
@add_start_docstrings(
"""
@ -1356,6 +1418,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
self.logits_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1423,6 +1486,20 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "logits_proj", None) is not None:
with tf.name_scope(self.logits_proj.name):
self.logits_proj.build([None, None, self.config.d_model])
@add_start_docstrings(
"""
@ -1442,6 +1519,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
self.logits_proj = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@ -1524,6 +1602,20 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "sequence_summary", None) is not None:
with tf.name_scope(self.sequence_summary.name):
self.sequence_summary.build(None)
if getattr(self, "logits_proj", None) is not None:
with tf.name_scope(self.logits_proj.name):
self.logits_proj.build([None, None, self.config.d_model])
@add_start_docstrings(
"""
@ -1541,6 +1633,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1604,6 +1697,17 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
"""
@ -1619,6 +1723,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.config = config
@unpack_inputs
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -1697,3 +1802,14 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "transformer", None) is not None:
with tf.name_scope(self.transformer.name):
self.transformer.build(None)
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])

View File

@ -2161,16 +2161,8 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@ -2359,16 +2351,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
@ -2578,6 +2562,13 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + '/' + self.shared.name + '/'):
self.shared.build(None)
@add_start_docstrings(
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",

View File

@ -1071,9 +1071,9 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
encoder = TFBertModel(config.encoder)
encoder.build()
encoder.build_in_name_scope()
decoder = TFBertLMHeadModel(config.decoder)
decoder.build()
decoder.build_in_name_scope()
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)

View File

@ -180,7 +180,7 @@ class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
else:
# Here we build the word embeddings weights if not exists.
# And then we retry to get the attribute once built.
model.build()
model.build_in_name_scope()
if hasattr(embedding_layer, "weight"):
return embedding_layer.weight
else:

View File

@ -729,9 +729,9 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
encoder = TFViTModel(config.encoder)
encoder.build()
encoder.build_in_name_scope()
decoder = TFGPT2LMHeadModel(config.decoder)
decoder.build()
decoder.build_in_name_scope()
encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

View File

@ -290,7 +290,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
for model_class in self.all_model_classes:
model = model_class(config)
model.build()
model.build_in_name_scope()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=False)

View File

@ -21,7 +21,7 @@ from transformers import (
TFPreTrainedModel,
pipeline,
)
from transformers.testing_utils import get_gpu_count, is_pipeline_test, require_tf, require_torch, slow, torch_device
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow, torch_device
from transformers.tokenization_utils import TruncationStrategy
from .test_pipelines_common import ANY
@ -67,8 +67,8 @@ class SummarizationPipelineTests(unittest.TestCase):
# the embedding layer.
if not (
isinstance(model, TFPreTrainedModel)
and get_gpu_count() > 0
and len(summarizer.model.trainable_weights) > 0
and "GPU" in summarizer.model.trainable_weights[0].device
):
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)