Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Update the TF models to remove their interdependencies (#7238)
* Refactor the models to remove their interdependencies
* Fix Flaubert model
* Fix Flaubert
* Fix XLM
* Fix Albert
* Fix Roberta
* Fix Albert
* Fix Flaubert
* Apply style + remove unused imports
* Fix Distilbert
* Remove unused import
* Fix Distilbert
* Fix Flaubert
* Apply style
* Fix Flaubert
* Add the copy comments for the check_copies script
* Fix MobileBert model name
* Address Morgan's comments
* Fix typo
* Oops typo
Parent: 0cffa424f8
Commit: d161ed1682
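The refactor relies on the "# Copied from ..." marker comments added throughout the diff below, so that the check_copies consistency script can keep the now-duplicated layers in sync. The following is a minimal sketch of that convention only, not of the checking script itself; the trimmed layer bodies are illustrative:

import tensorflow as tf


class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        return self.dropout(hidden_states, training=training) + input_tensor


# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFElectraSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        return self.dropout(hidden_states, training=training) + input_tensor

The marker names the module-qualified source class; a "with Bert->Electra" suffix records a mechanical rename applied to the copied body.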
@@ -31,7 +31,6 @@ from .file_utils import (
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_tf_bert import TFBertSelfAttention
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
@@ -181,82 +180,6 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
        return tf.reshape(logits, [batch_size, length, self.config.vocab_size])


class TFAlbertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        assert (
            config.hidden_size % config.num_attention_heads == 0
        ), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.output_attentions = config.output_attentions

        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        # scale attention_scores
        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)

        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs


class TFAlbertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
@@ -273,7 +196,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
        return hidden_states


class TFAlbertAttention(TFBertSelfAttention):
class TFAlbertAttention(tf.keras.layers.Layer):
    """ Contains the complete attention sublayer, including both dropouts and layer norm. """

    def __init__(self, config, **kwargs):
@@ -281,6 +204,19 @@ class TFAlbertAttention(TFBertSelfAttention):

        self.hidden_size = config.hidden_size
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        assert config.hidden_size % config.num_attention_heads == 0
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
@@ -290,6 +226,11 @@ class TFAlbertAttention(TFBertSelfAttention):
        self.attention_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))

        return tf.transpose(x, perm=[0, 2, 1, 3])

    def prune_heads(self, heads):
        raise NotImplementedError

@@ -342,6 +283,7 @@ class TFAlbertAttention(TFBertSelfAttention):

        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs
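Every model file now carries its own copy of this self-attention block instead of importing it from modeling_tf_bert. The computation they all share is standard scaled dot-product attention; a minimal standalone TensorFlow sketch of the core step, independent of the classes in this diff (tensor names are illustrative):

import tensorflow as tf


def scaled_dot_product_attention(query, key, value, additive_mask=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # additive_mask: 0.0 for kept positions, a large negative value for masked ones
    dk = tf.cast(tf.shape(key)[-1], query.dtype)
    scores = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(dk)
    if additive_mask is not None:
        scores = scores + additive_mask
    probs = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(probs, value), probs


# Example: batch=2, heads=4, seq_len=5, head_dim=8
q = tf.random.normal((2, 4, 5, 8))
context, attn = scaled_dot_product_attention(q, q, q)
print(context.shape, attn.shape)  # (2, 4, 5, 8) (2, 4, 5, 5)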
@ -93,6 +93,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.initializer_range = config.initializer_range
|
||||
@ -124,6 +125,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(
|
||||
@ -273,6 +275,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||
class TFBertSelfOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
@ -290,6 +293,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
||||
class TFBertAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.self_attention = TFBertSelfAttention(config, name="self")
|
||||
self.dense_output = TFBertSelfOutput(config, name="output")
|
||||
|
||||
@ -309,6 +313,7 @@ class TFBertAttention(tf.keras.layers.Layer):
|
||||
class TFBertIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
@ -328,6 +333,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
||||
class TFBertOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
@ -345,6 +351,7 @@ class TFBertOutput(tf.keras.layers.Layer):
|
||||
class TFBertLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.attention = TFBertAttention(config, name="attention")
|
||||
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
||||
self.bert_output = TFBertOutput(config, name="output")
|
||||
@ -364,6 +371,7 @@ class TFBertLayer(tf.keras.layers.Layer):
|
||||
class TFBertEncoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(
|
||||
@ -397,6 +405,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
@ -405,6 +414,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
||||
class TFBertPooler(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@ -424,6 +434,7 @@ class TFBertPooler(tf.keras.layers.Layer):
|
||||
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
@ -446,6 +457,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
||||
class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.transform = TFBertPredictionHeadTransform(config, name="transform")
|
||||
|
||||
@ -455,6 +467,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states):
|
||||
@ -468,6 +481,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
class TFBertMLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
|
||||
|
||||
def call(self, sequence_output):
|
||||
@ -479,6 +493,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
|
||||
class TFBertNSPHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.seq_relationship = tf.keras.layers.Dense(
|
||||
2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
|
||||
)
|
||||
@ -495,6 +510,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.initializer_range = config.initializer_range
|
||||
self.output_attentions = config.output_attentions
|
||||
@ -571,6 +587,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
@ -588,7 +605,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
@ -767,6 +783,7 @@ BERT_INPUTS_DOCSTRING = r"""
|
||||
class TFBertModel(TFBertPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@ -778,6 +795,7 @@ class TFBertModel(TFBertPreTrainedModel):
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@ -818,7 +836,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
||||
seq_relationship_score = self.nsp(pooled_output)
|
||||
@ -880,6 +897,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
@ -902,7 +920,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=training)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
@ -956,6 +973,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
@ -978,8 +996,8 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
|
||||
sequence_output = outputs[0]
|
||||
logits = self.mlm(sequence_output, training=training)
|
||||
|
||||
loss = None
|
||||
|
||||
if labels is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
@ -1033,7 +1051,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_score = self.nsp(pooled_output)
|
||||
|
||||
@ -1055,8 +1072,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
@ -1092,6 +1109,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
@ -1113,10 +1131,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@ -1208,6 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
@ -1242,7 +1259,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
@ -1265,8 +1281,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
@ -1300,6 +1316,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
@ -1319,12 +1336,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@ -1347,8 +1361,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
@ -1387,6 +1401,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
@ -1408,15 +1423,13 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
|
@ -16,8 +16,6 @@
|
||||
"""
|
||||
|
||||
|
||||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from .activations_tf import get_tf_activation
|
||||
@ -217,9 +215,8 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
||||
k_length = shape_list(key)[1]
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
# assert key.size() == value.size()
|
||||
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
|
||||
dim_per_head = tf.math.divide(self.dim, self.n_heads)
|
||||
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
|
||||
mask_reshape = [bs, 1, 1, k_length]
|
||||
|
||||
def shape(x):
|
||||
@ -233,17 +230,16 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
||||
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
|
||||
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
|
||||
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
|
||||
|
||||
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
|
||||
q = tf.cast(q, dtype=tf.float32)
|
||||
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
|
||||
k = tf.cast(k, dtype=q.dtype)
|
||||
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length)
|
||||
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
|
||||
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length)
|
||||
|
||||
scores_dtype = scores.dtype
|
||||
# calculate `scores` in `tf.float32` to avoid numeric overflow
|
||||
scores = tf.cast(scores, dtype=tf.float32) - 1e30 * (1.0 - tf.cast(mask, dtype=tf.float32))
|
||||
|
||||
weights = tf.cast(tf.nn.softmax(scores, axis=-1), dtype=scores_dtype) # (bs, n_heads, qlen, klen)
|
||||
mask = tf.cast(mask, dtype=scores.dtype)
|
||||
scores = scores - 1e30 * (1.0 - mask)
|
||||
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
|
||||
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
|
||||
|
||||
# Mask heads if we want to
|
||||
|
@ -13,7 +13,6 @@ from .file_utils import (
|
||||
add_start_docstrings_to_callable,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_tf_bert import TFBertEncoder, TFBertPreTrainedModel
|
||||
from .modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFMaskedLMOutput,
|
||||
@ -25,6 +24,7 @@ from .modeling_tf_outputs import (
|
||||
from .modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFMultipleChoiceLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
TFSequenceSummary,
|
||||
@ -53,15 +53,253 @@ TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
|
||||
class TFElectraSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
)
|
||||
self.key = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
||||
)
|
||||
self.value = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, batch_size):
|
||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = tf.matmul(
|
||||
query_layer, key_layer, transpose_b=True
|
||||
) # (batch size, num_heads, seq_len_q, seq_len_k)
|
||||
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
|
||||
attention_scores = attention_scores / tf.math.sqrt(dk)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs, training=training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = tf.matmul(attention_probs, value_layer)
|
||||
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
|
||||
context_layer = tf.reshape(
|
||||
context_layer, (batch_size, -1, self.all_head_size)
|
||||
) # (batch_size, seq_len_q, all_head_size)
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
|
||||
class TFElectraSelfOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Electra
|
||||
class TFElectraAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.self_attention = TFElectraSelfAttention(config, name="self")
|
||||
self.dense_output = TFElectraSelfOutput(config, name="output")
|
||||
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
|
||||
self_outputs = self.self_attention(
|
||||
input_tensor, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
|
||||
class TFElectraIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertOutput
|
||||
class TFElectraOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Electra
|
||||
class TFElectraLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.attention = TFElectraAttention(config, name="attention")
|
||||
self.intermediate = TFElectraIntermediate(config, name="intermediate")
|
||||
self.bert_output = TFElectraOutput(config, name="output")
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||
attention_outputs = self.attention(
|
||||
hidden_states, attention_mask, head_mask, output_attentions, training=training
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Electra
|
||||
class TFElectraEncoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.layer = [TFElectraLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertPooler
|
||||
class TFElectraPooler(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="dense",
|
||||
)
|
||||
|
||||
def call(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embedding_size = config.embedding_size
|
||||
self.initializer_range = config.initializer_range
|
||||
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.embedding_size,
|
||||
@ -90,11 +328,13 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
shape=[self.vocab_size, self.embedding_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertEmbeddings.call
|
||||
def call(
|
||||
self,
|
||||
input_ids,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@ -122,6 +362,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertEmbeddings._embedding
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
@ -132,19 +373,22 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
|
||||
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
|
||||
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
|
||||
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
@ -156,7 +400,6 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
|
||||
x = tf.reshape(inputs, [-1, self.embedding_size])
|
||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||
|
||||
@ -194,54 +437,28 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFElectraPreTrainedModel(TFBertPreTrainedModel):
|
||||
class TFElectraPreTrainedModel(TFPreTrainedModel):
|
||||
"""An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = ElectraConfig
|
||||
base_model_prefix = "electra"
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, dtype)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
return extended_attention_mask
|
||||
|
||||
def get_head_mask(self, head_mask):
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFElectraMainLayer(TFElectraPreTrainedModel):
|
||||
|
||||
class TFElectraMainLayer(tf.keras.layers.Layer):
|
||||
config_class = ElectraConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.embeddings = TFElectraEmbeddings(config, name="embeddings")
|
||||
|
||||
if config.embedding_size != config.hidden_size:
|
||||
self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
|
||||
self.encoder = TFBertEncoder(config, name="encoder")
|
||||
|
||||
self.encoder = TFElectraEncoder(config, name="encoder")
|
||||
self.config = config
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@ -261,6 +478,35 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
    def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask
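The -10000.0 bias above turns the binary padding mask into an additive term on the raw attention scores, which the self-attention layers add just before the softmax. A small worked example (values are illustrative):

import tensorflow as tf

# Batch of 2 sequences of length 4; 1 marks real tokens, 0 marks padding.
attention_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])

# Broadcastable (batch, 1, 1, seq_len): 0.0 where we attend, -10000.0 where we mask.
extended = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)
extended = (1.0 - extended) * -10000.0

scores = tf.zeros((2, 1, 4, 4))  # stand-in for raw attention scores
probs = tf.nn.softmax(scores + extended, axis=-1)
print(probs[0, 0, 0].numpy())  # ~[0.333, 0.333, 0.333, 0.0]; the padded position gets ~zero weight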
|
||||
def get_head_mask(self, head_mask):
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
return head_mask
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
@ -316,11 +562,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
|
||||
@ -462,6 +708,7 @@ ELECTRA_INPUTS_DOCSTRING = r"""
|
||||
class TFElectraModel(TFElectraPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@ -473,6 +720,7 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.electra(inputs, **kwargs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@ -521,7 +769,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
|
||||
>>> scores = outputs[0]
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@ -550,16 +797,19 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
|
||||
class TFElectraMaskedLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states, training=False):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -577,10 +827,12 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
self.vocab_size = config.vocab_size
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.activation = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.activation = config.hidden_act
|
||||
|
||||
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
@ -615,8 +867,10 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(input_ids, (tuple, list)):
|
||||
labels = input_ids[9] if len(input_ids) > 9 else labels
|
||||
|
||||
if len(input_ids) > 9:
|
||||
input_ids = input_ids[:9]
|
||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||
@ -637,11 +891,11 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
generator_sequence_output = generator_hidden_states[0]
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + generator_hidden_states[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFMaskedLMOutput(
|
||||
@ -657,6 +911,7 @@ class TFElectraClassificationHead(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
@ -717,8 +972,10 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(input_ids, (tuple, list)):
|
||||
labels = input_ids[9] if len(input_ids) > 9 else labels
|
||||
|
||||
if len(input_ids) > 9:
|
||||
input_ids = input_ids[:9]
|
||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||
@ -737,11 +994,11 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
||||
training=training,
|
||||
)
|
||||
logits = self.classifier(outputs[0])
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFSequenceClassifierOutput(
|
||||
@ -831,6 +1088,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
@ -864,11 +1122,11 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
logits = self.sequence_summary(outputs[0])
|
||||
logits = self.classifier(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFMultipleChoiceModelOutput(
|
||||
@ -922,8 +1180,10 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
@ -944,11 +1204,11 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
|
||||
logits = self.classifier(discriminator_sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + discriminator_hidden_states[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFTokenClassifierOutput(
|
||||
@ -967,8 +1227,8 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
@ -1007,9 +1267,11 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
@ -1029,13 +1291,12 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
training=training,
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
|
||||
logits = self.qa_outputs(discriminator_sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
@ -1046,6 +1307,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
start_logits,
|
||||
end_logits,
|
||||
) + discriminator_hidden_states[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFQuestionAnsweringModelOutput(
|
||||
|
@ -15,24 +15,23 @@
|
||||
""" TF 2.0 Flaubert model.
|
||||
"""
|
||||
|
||||
import random
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.activations_tf import get_tf_activation
|
||||
|
||||
from .configuration_flaubert import FlaubertConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .file_utils import ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_outputs import TFBaseModelOutput
|
||||
from .modeling_tf_utils import keras_serializable, shape_list
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
|
||||
from .modeling_tf_xlm import (
|
||||
TFXLMForMultipleChoice,
|
||||
TFXLMForQuestionAnsweringSimple,
|
||||
TFXLMForSequenceClassification,
|
||||
TFXLMForTokenClassification,
|
||||
TFXLMMainLayer,
|
||||
TFXLMModel,
|
||||
TFXLMPredLayer,
|
||||
TFXLMWithLMHeadModel,
|
||||
get_masks,
|
||||
)
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .utils import logging
|
||||
@ -40,6 +39,9 @@ from .utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "FlaubertConfig"
|
||||
_TOKENIZER_FOR_DOC = "FlaubertTokenizer"
|
||||
|
||||
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
# See all Flaubert models at https://huggingface.co/models?filter=flaubert
|
||||
]
|
||||
@ -155,27 +157,258 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
    """
    Generate hidden states mask, and optionally an attention mask.
    """
    bs = shape_list(lengths)[0]
    if padding_mask is not None:
        mask = padding_mask
    else:
        # assert lengths.max().item() <= slen
        alen = tf.range(slen)
        mask = tf.math.less(alen, lengths[:, tf.newaxis])

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = tf.less_equal(
            tf.tile(alen[tf.newaxis, tf.newaxis, :], (bs, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
        )
    else:
        attn_mask = mask

    # sanity check
    # assert shape_list(mask) == [bs, slen]
    tf.debugging.assert_equal(shape_list(mask), [bs, slen])
    assert causal is False or shape_list(attn_mask) == [bs, slen, slen]

    mask = tf.cast(mask, dtype=dtype)
    attn_mask = tf.cast(attn_mask, dtype=dtype)

    return mask, attn_mask
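For reference, the two masks this helper produces for a padded batch, written out with plain TensorFlow ops (a self-contained approximation of the non-causal and causal branches above; shapes and lengths are illustrative):

import tensorflow as tf

slen, lengths = 4, tf.constant([3, 2])  # two sequences padded to length 4, true lengths 3 and 2
alen = tf.range(slen)

# Non-causal: True for real tokens, False for padding -> shape (2, 4)
mask = tf.math.less(alen, lengths[:, tf.newaxis])
# [[ True,  True,  True, False],
#  [ True,  True, False, False]]

# Causal: attn_mask[b, q, k] is True only where k <= q (lower-triangular) -> shape (2, 4, 4)
attn_mask = tf.less_equal(
    tf.tile(alen[tf.newaxis, tf.newaxis, :], (2, slen, 1)), alen[tf.newaxis, :, tf.newaxis]
)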
|
||||
|
||||
class TFFlaubertPreTrainedModel(TFPreTrainedModel):
|
||||
"""An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = FlaubertConfig
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
|
||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||
else:
|
||||
langs_list = None
|
||||
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
"The bare Flaubert Model transformer outputing raw hidden-states without any specific head on top.",
|
||||
FLAUBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFFlaubertModel(TFXLMModel):
|
||||
config_class = FlaubertConfig
|
||||
|
||||
class TFFlaubertModel(TFFlaubertPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
||||
|
||||
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="jplu/tf-flaubert-small-cased",
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_xlm.TFXLMMultiHeadAttention with XLM->Flaubert
|
||||
class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
|
||||
NEW_ID = itertools.count()
|
||||
|
||||
def __init__(self, n_heads, dim, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.layer_id = next(TFFlaubertMultiHeadAttention.NEW_ID)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.output_attentions = config.output_attentions
|
||||
assert self.dim % self.n_heads == 0
|
||||
|
||||
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
|
||||
self.k_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="k_lin")
|
||||
self.v_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="v_lin")
|
||||
self.out_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="out_lin")
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
|
||||
"""
|
||||
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||
"""
|
||||
# Input is (bs, qlen, dim)
|
||||
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
||||
bs, qlen, dim = shape_list(input)
|
||||
|
||||
if kv is None:
|
||||
klen = qlen if cache is None else cache["slen"] + qlen
|
||||
else:
|
||||
klen = shape_list(kv)[1]
|
||||
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
dim_per_head = tf.math.divide(self.dim, self.n_heads)
|
||||
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
|
||||
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
|
||||
|
||||
def shape(x):
|
||||
""" projection """
|
||||
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
|
||||
|
||||
def unshape(x):
|
||||
""" compute context """
|
||||
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
|
||||
|
||||
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||
|
||||
if kv is None:
|
||||
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||
elif cache is None or self.layer_id not in cache:
|
||||
k = v = kv
|
||||
k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head)
|
||||
v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head)
|
||||
|
||||
if cache is not None:
|
||||
if self.layer_id in cache:
|
||||
if kv is None:
|
||||
k_, v_ = cache[self.layer_id]
|
||||
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
|
||||
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
|
||||
else:
|
||||
k, v = cache[self.layer_id]
|
||||
|
||||
cache[self.layer_id] = (k, v)
|
||||
|
||||
q = tf.cast(q, dtype=tf.float32)
|
||||
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
|
||||
k = tf.cast(k, dtype=q.dtype)
|
||||
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
|
||||
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
|
||||
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
|
||||
mask = tf.cast(mask, dtype=scores.dtype)
|
||||
scores = scores - 1e30 * (1.0 - mask)
|
||||
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
|
||||
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
weights = weights * head_mask
|
||||
|
||||
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
|
||||
context = unshape(context) # (bs, qlen, dim)
|
||||
outputs = (self.out_lin(context),)
|
||||
|
||||
if output_attentions:
|
||||
outputs = outputs + (weights,)
|
||||
|
||||
return outputs
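
# Editor's note: a minimal sketch (not part of this commit) of the additive masking step used
# in the attention above: masked positions receive a large negative offset so that softmax
# gives them effectively zero weight. The numbers below are made up.
import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5, -1.0]])    # raw attention scores, shape (1, klen)
mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])       # 1 = attend, 0 = masked out
masked_scores = scores - 1e30 * (1.0 - mask)
weights = tf.nn.softmax(masked_scores, axis=-1)  # last two entries are ~0.0
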
# Copied from transformers.modeling_tf_xlm.TFXLMTransformerFFN
|
||||
class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
|
||||
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
|
||||
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
|
||||
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def call(self, input, training=False):
|
||||
x = self.lin1(input)
|
||||
x = self.act(x)
|
||||
x = self.lin2(x)
|
||||
x = self.dropout(x, training=training)
|
||||
|
||||
return x
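
# Editor's note: an illustrative usage sketch (not part of this commit). `SimpleNamespace` only
# mimics the three config attributes the layer reads (init_std, gelu_activation, dropout); it is
# not how a FlaubertConfig is normally constructed.
from types import SimpleNamespace
import tensorflow as tf

toy_config = SimpleNamespace(init_std=0.02, gelu_activation=True, dropout=0.1)
ffn = TFFlaubertTransformerFFN(in_dim=64, dim_hidden=256, out_dim=64, config=toy_config, name="ffn")
hidden = tf.random.normal((2, 10, 64))
output = ffn(hidden, training=False)  # shape (2, 10, 64)
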
@keras_serializable
|
||||
class TFFlaubertMainLayer(TFXLMMainLayer):
|
||||
class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = FlaubertConfig
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.n_heads = config.n_heads
|
||||
self.n_langs = config.n_langs
|
||||
self.dim = config.emb_dim
|
||||
self.hidden_dim = self.dim * 4
|
||||
self.n_words = config.n_words
|
||||
self.pad_index = config.pad_index
|
||||
self.causal = config.causal
|
||||
self.n_layers = config.n_layers
|
||||
self.use_lang_emb = config.use_lang_emb
|
||||
self.layerdrop = getattr(config, "layerdrop", 0.0)
|
||||
self.pre_norm = getattr(config, "pre_norm", False)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.return_dict = config.use_return_dict
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="position_embeddings",
|
||||
)
|
||||
|
||||
if config.n_langs > 1 and config.use_lang_emb:
|
||||
self.lang_embeddings = tf.keras.layers.Embedding(
|
||||
self.n_langs,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="lang_embeddings",
|
||||
)
|
||||
|
||||
self.embeddings = TFSharedEmbeddings(
|
||||
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
|
||||
)
|
||||
self.layer_norm_emb = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm_emb")
|
||||
self.attentions = []
|
||||
self.layer_norm1 = []
|
||||
self.ffns = []
|
||||
self.layer_norm2 = []
|
||||
|
||||
for i in range(self.n_layers):
|
||||
self.attentions.append(
|
||||
TFFlaubertMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
|
||||
)
|
||||
self.layer_norm1.append(
|
||||
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
|
||||
)
|
||||
# if self.is_decoder:
|
||||
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
|
||||
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
|
||||
self.ffns.append(
|
||||
TFFlaubertTransformerFFN(
|
||||
self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i)
|
||||
)
|
||||
)
|
||||
self.layer_norm2.append(
|
||||
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def call(
|
||||
self,
|
||||
@ -305,21 +538,26 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
|
||||
if langs is not None and self.use_lang_emb:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
tensor = tensor + self.embeddings(token_type_ids)
|
||||
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = self.dropout(tensor, training=training)
|
||||
tensor = tensor * mask[..., tf.newaxis]
|
||||
|
||||
# hidden_states and attentions cannot be None in graph mode.
|
||||
hidden_states = ()
|
||||
attentions = ()
|
||||
|
||||
# transformer layers
|
||||
hidden_states = () if output_hidden_states else None
|
||||
attentions = () if output_attentions else None
|
||||
for i in range(self.n_layers):
|
||||
# LayerDrop
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if training and (dropout_probability < self.layerdrop):
|
||||
dropout_probability = tf.random.uniform([1], 0, 1)
|
||||
|
||||
if training and tf.less(dropout_probability, self.layerdrop):
|
||||
continue
|
||||
|
||||
if output_hidden_states:
|
||||
@ -331,8 +569,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
||||
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
|
||||
attn = self.dropout(attn, training=training)
|
||||
tensor = tensor + attn
|
||||
tensor = self.layer_norm1[i](tensor)
|
||||
@ -342,8 +582,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
||||
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
|
||||
attn = self.dropout(attn, training=training)
|
||||
tensor = tensor + attn
|
||||
|
||||
@ -375,23 +617,129 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
||||
# move back sequence length to dimension 0
|
||||
# tensor = tensor.transpose(0, 1)
|
||||
|
||||
# Set to None here if the output booleans are at False
|
||||
hidden_states = hidden_states if output_hidden_states else None
|
||||
attentions = attentions if output_attentions else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
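
# Editor's note: a simplified sketch (not part of this commit) of the LayerDrop behaviour used
# in the layer loop above — during training, each layer is skipped with probability `layerdrop`.
# The helper below is hypothetical and assumes eager execution.
import tensorflow as tf

def apply_layers_with_layerdrop(tensor, layers, layerdrop=0.1, training=True):
    for layer in layers:
        dropout_probability = tf.random.uniform([], 0, 1)
        if training and tf.less(dropout_probability, layerdrop):
            continue  # skip this layer for this forward pass
        tensor = layer(tensor)
    return tensor
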
# Copied from transformers.modeling_tf_xlm.TFXLMPredLayer
|
||||
class TFFlaubertPredLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
Prediction layer (cross_entropy or adaptive_softmax).
|
||||
"""
|
||||
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.asm = config.asm
|
||||
self.n_words = config.n_words
|
||||
self.pad_index = config.pad_index
|
||||
|
||||
if config.asm is False:
|
||||
self.input_embeddings = input_embeddings
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# self.proj = nn.AdaptiveLogSoftmaxWithLoss(
|
||||
# in_features=dim,
|
||||
# n_classes=config.n_words,
|
||||
# cutoffs=config.asm_cutoffs,
|
||||
# div_value=config.asm_div_value,
|
||||
# head_bias=True, # default is False
|
||||
# )
|
||||
|
||||
def build(self, input_shape):
|
||||
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
|
||||
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
|
||||
return hidden_states
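
# Editor's note: a minimal sketch (not part of this commit) of the weight tying performed above.
# The prediction layer reuses the input embedding matrix in "linear" mode, i.e. it computes
# logits = hidden_states @ embedding_matrix^T + bias. All shapes below are made up.
import tensorflow as tf

hidden_states = tf.random.normal((2, 7, 64))     # (batch_size, seq_len, emb_dim)
embedding_matrix = tf.random.normal((1000, 64))  # (n_words, emb_dim), shared with the input embeddings
bias = tf.zeros((1000,))

logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + bias  # (2, 7, 1000)
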
@dataclass
|
||||
class TFFlaubertWithLMHeadModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for :class:`~transformers.TFFlaubertWithLMHeadModel` outputs.
|
||||
|
||||
Args:
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Flaubert Model transformer with a language modeling head on top
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
FLAUBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
|
||||
config_class = FlaubertConfig
|
||||
|
||||
class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
||||
self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
|
||||
self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.pred_layer.input_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||
mask_token_id = self.config.mask_token_id
|
||||
lang_id = self.config.lang_id
|
||||
|
||||
effective_batch_size = inputs.shape[0]
|
||||
mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
|
||||
inputs = tf.concat([inputs, mask_token], axis=1)
|
||||
|
||||
if lang_id is not None:
|
||||
langs = tf.ones_like(inputs) * lang_id
|
||||
else:
|
||||
langs = None
|
||||
return {"inputs": inputs, "langs": langs}
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="jplu/tf-flaubert-small-cased",
|
||||
output_type=TFFlaubertWithLMHeadModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
outputs = self.pred_layer(output)
|
||||
|
||||
if not return_dict:
|
||||
return (outputs,) + transformer_outputs[1:]
|
||||
|
||||
return TFFlaubertWithLMHeadModelOutput(
|
||||
logits=outputs, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions
|
||||
)
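
# Editor's note: a minimal usage sketch (not part of this commit), assuming the
# `jplu/tf-flaubert-small-cased` checkpoint referenced in the docstrings above is available.
from transformers import FlaubertTokenizer, TFFlaubertWithLMHeadModel

tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-small-cased")
model = TFFlaubertWithLMHeadModel.from_pretrained("jplu/tf-flaubert-small-cased")

inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
outputs = model(inputs)
logits = outputs[0]  # (batch_size, sequence_length, vocab_size)
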
@add_start_docstrings(
|
||||
|
@ -16,16 +16,16 @@
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.activations_tf import get_tf_activation
|
||||
|
||||
from .configuration_longformer import LongformerConfig
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
|
||||
from .modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
TFMaskedLMOutput,
|
||||
TFQuestionAnsweringModelOutput,
|
||||
)
|
||||
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
|
||||
from .modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
@ -84,18 +84,280 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
||||
return attention_mask
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_roberta.TFRobertaLMHead
|
||||
class TFLongformerLMHead(tf.keras.layers.Layer):
|
||||
"""Roberta Head for masked language modeling."""
|
||||
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
||||
self.act = get_tf_activation("gelu")
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = input_embeddings
|
||||
|
||||
def build(self, input_shape):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, features):
|
||||
x = self.dense(features)
|
||||
x = self.act(x)
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x, mode="linear") + self.bias
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_roberta.TFRobertaEmbeddings
|
||||
class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
"""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.padding_idx = 1
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.initializer_range = config.initializer_range
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.token_type_embeddings = tf.keras.layers.Embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="token_type_embeddings",
|
||||
)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape):
|
||||
"""Build shared word embedding layer """
|
||||
with tf.name_scope("word_embeddings"):
|
||||
# Create and initialize weights. The random normal initializer was chosen
|
||||
# arbitrarily, and works well.
|
||||
self.word_embeddings = self.add_weight(
|
||||
"weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def create_position_ids_from_input_ids(self, x):
|
||||
"""Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
:param tf.Tensor x:
|
||||
:return tf.Tensor:
|
||||
"""
|
||||
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
|
||||
incremental_indicies = tf.math.cumsum(mask, axis=1) * mask
|
||||
|
||||
return incremental_indicies + self.padding_idx
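
# Editor's note: a small worked example (not part of this commit) of the position id scheme
# above — padding tokens keep position `padding_idx`, real tokens are numbered from
# `padding_idx + 1` onwards. The input ids are made up; the pad id is 1 as in RoBERTa.
import tensorflow as tf

padding_idx = 1
x = tf.constant([[0, 5, 8, 1, 1]])                                 # 1 marks padding
mask = tf.cast(tf.math.not_equal(x, padding_idx), dtype=tf.int32)  # [[1, 1, 1, 0, 0]]
position_ids = tf.math.cumsum(mask, axis=1) * mask + padding_idx   # [[2, 3, 4, 1, 1]]
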
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
"""We are provided embeddings directly. We cannot infer which are padded so just generate
|
||||
sequential position ids.
|
||||
:param tf.Tensor inputs_embeds:
|
||||
:return tf.Tensor:
|
||||
"""
|
||||
seq_length = shape_list(inputs_embeds)[1]
|
||||
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
return position_ids
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
mode: string, a valid value is one of "embedding" and "linear".
|
||||
Returns:
|
||||
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
|
||||
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
|
||||
linear tensor, float32 with shape [batch_size, length, vocab_size].
|
||||
Raises:
|
||||
ValueError: if mode is not valid.
|
||||
|
||||
Shared weights logic adapted from
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = self.create_position_ids_from_input_ids(input_ids)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
else:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
|
||||
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
|
||||
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
"""Computes logits by running inputs through a linear layer.
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
|
||||
Returns:
|
||||
float32 tensor with shape [batch_size, length, vocab_size].
|
||||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||
|
||||
return tf.reshape(logits, [batch_size, length, self.vocab_size])
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
|
||||
class TFLongformerIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertOutput
|
||||
class TFLongformerOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.modeling_tf_bert.TFBertPooler
|
||||
class TFLongformerPooler(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="dense",
|
||||
)
|
||||
|
||||
def call(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
|
||||
return pooled_output
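
# Editor's note: an illustrative sketch (not part of this commit) of the pooling above — only
# the hidden state of the first token is kept and passed through a dense layer with a tanh
# activation. The shapes are made up.
import tensorflow as tf

hidden_states = tf.random.normal((2, 12, 768))  # (batch_size, seq_len, hidden_size)
first_token_tensor = hidden_states[:, 0]        # (2, 768)
dense = tf.keras.layers.Dense(768, activation="tanh")
pooled_output = dense(first_token_tensor)       # (2, 768)
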
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
|
||||
class TFLongformerSelfOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
)
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = int(config.hidden_size / config.num_attention_heads)
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@ -128,13 +390,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="value_global",
|
||||
)
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
attention_window = config.attention_window[self.layer_id]
|
||||
|
||||
assert (
|
||||
attention_window % 2 == 0
|
||||
), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
|
||||
@ -173,8 +433,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
query_vectors = self.query(hidden_states)
|
||||
key_vectors = self.key(hidden_states)
|
||||
value_vectors = self.value(hidden_states)
|
||||
|
||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
@ -183,7 +443,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# normalize query
|
||||
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||
|
||||
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
|
||||
@ -217,7 +476,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
) = self._get_global_attn_indices(is_index_global_attn)
|
||||
|
||||
# this function is only relevant for global attention
|
||||
|
||||
attn_scores = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: self._concat_with_global_key_attn_probs(
|
||||
@ -243,7 +501,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# apply dropout
|
||||
attn_probs = self.dropout(attn_probs, training=training)
|
||||
|
||||
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
|
||||
# if global attention, compute sum of global and local attn
|
||||
@ -266,6 +523,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
||||
message="Unexpected size",
|
||||
)
|
||||
|
||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||
|
||||
# compute value for global attention and overwrite to attention output
|
||||
@ -303,6 +561,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
outputs = (attn_output, attn_probs)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@ -322,6 +581,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
|
||||
with an overlap of size window_overlap"""
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
@ -341,7 +601,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
(batch_size * num_heads, seq_len, head_dim),
|
||||
)
|
||||
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
||||
|
||||
chunked_query = self._chunk(query, window_overlap)
|
||||
chunked_key = self._chunk(key, window_overlap)
|
||||
|
||||
@ -390,7 +649,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
first_chunk_mask = (
|
||||
tf.broadcast_to(
|
||||
tf.range(chunks_count + 1)[None, :, None, None],
|
||||
@ -403,7 +661,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
< 1
|
||||
)
|
||||
|
||||
diagonal_attn_scores_low_triang = tf.where(
|
||||
first_chunk_mask,
|
||||
diagonal_attn_scores_first_chunk,
|
||||
@ -425,6 +682,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
|
||||
|
||||
return diagonal_attention_scores
|
||||
|
||||
@staticmethod
|
||||
@ -434,6 +692,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
|
||||
axis=[0],
|
||||
)
|
||||
|
||||
# pad to full matrix
|
||||
padding = tf.constant(
|
||||
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
|
||||
@ -441,6 +700,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# create lower mask
|
||||
mask_2d = tf.pad(mask_2d_upper, padding)
|
||||
|
||||
# combine with upper mask
|
||||
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
||||
|
||||
@ -456,7 +716,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
return input_tensor
|
||||
|
||||
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
|
||||
|
||||
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
|
||||
Returned tensor will be of the same shape as `attn_probs`"""
|
||||
|
||||
@ -479,8 +738,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
||||
|
||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
||||
chunked_attn_probs = tf.reshape(
|
||||
tf.transpose(attn_probs, (0, 2, 1, 3)),
|
||||
(
|
||||
@ -498,15 +757,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
|
||||
|
||||
paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
|
||||
padded_value = tf.pad(value, paddings, constant_values=-1)
|
||||
|
||||
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
|
||||
|
||||
frame_size = 3 * window_overlap * head_dim
|
||||
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
|
||||
|
||||
chunked_value = tf.signal.frame(
|
||||
tf.reshape(padded_value, (batch_size * num_heads, -1)),
|
||||
frame_size,
|
||||
@ -524,12 +780,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||
|
||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||
context = tf.transpose(
|
||||
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
|
||||
(0, 2, 1, 3),
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
@ -538,7 +794,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
hidden_states_padded = tf.pad(
|
||||
hidden_states_padded, paddings
|
||||
) # padding value is not important because it will be overwritten
|
||||
|
||||
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
|
||||
|
||||
@ -560,12 +815,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
|
||||
"""
|
||||
total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
|
||||
|
||||
paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
|
||||
chunked_hidden_states = tf.pad(
|
||||
chunked_hidden_states, paddings
|
||||
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
|
||||
|
||||
chunked_hidden_states = tf.reshape(
|
||||
chunked_hidden_states, (total_num_heads, num_chunks, -1)
|
||||
) # total_num_heads x num_chunks x window_overlap*(hidden_dim+window_overlap+1)
|
||||
@ -577,6 +830,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
|
||||
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
|
||||
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
|
||||
|
||||
return chunked_hidden_states
|
||||
|
||||
@staticmethod
|
||||
@ -588,7 +842,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# define frame size and frame stride (similar to convolution)
|
||||
frame_hop_size = window_overlap * hidden_dim
|
||||
frame_size = 2 * frame_hop_size
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
|
||||
|
||||
# chunk with overlap
|
||||
@ -651,6 +904,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# select global key vectors
|
||||
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
|
||||
|
||||
# create only global key vectors
|
||||
key_vectors_only_global = tf.scatter_nd(
|
||||
is_local_index_global_attn_nonzero,
|
||||
@ -665,6 +919,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
|
||||
|
||||
# (batch_size, max_num_global_attn_indices, seq_len, num_heads)
|
||||
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
|
||||
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
|
||||
@ -703,6 +958,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# select global value vectors
|
||||
global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
|
||||
|
||||
# create only global value vectors
|
||||
value_vectors_only_global = tf.scatter_nd(
|
||||
is_local_index_global_attn_nonzero,
|
||||
@ -725,6 +981,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
|
||||
attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
|
||||
)
|
||||
|
||||
return attn_output_only_global + attn_output_without_global
|
||||
|
||||
def _compute_global_attn_output_from_hidden(
|
||||
@ -755,7 +1012,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# normalize
|
||||
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||
|
||||
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
|
||||
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
|
||||
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
|
||||
@ -773,7 +1029,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
global_attn_scores,
|
||||
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
|
||||
)
|
||||
|
||||
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
|
||||
mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
|
||||
shape_list(global_attn_scores_trans)[-2:]
|
||||
@ -791,7 +1046,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# mask global attn scores
|
||||
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
|
||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores,
|
||||
(batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
|
||||
@ -828,10 +1082,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# overwrite values with global attention
|
||||
|
||||
attn_output = tf.tensor_scatter_nd_update(
|
||||
attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
def reshape_and_transpose(self, vector, batch_size):
|
||||
@ -847,8 +1101,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
class TFLongformerAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id=0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
|
||||
self.dense_output = TFBertSelfOutput(config, name="output")
|
||||
self.dense_output = TFLongformerSelfOutput(config, name="output")
|
||||
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
@ -868,17 +1123,18 @@ class TFLongformerAttention(tf.keras.layers.Layer):
|
||||
training=training,
|
||||
)
|
||||
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
|
||||
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFLongformerLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id=0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.attention = TFLongformerAttention(config, layer_id, name="attention")
|
||||
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
||||
self.longformer_output = TFBertOutput(config, name="output")
|
||||
self.intermediate = TFLongformerIntermediate(config, name="intermediate")
|
||||
self.longformer_output = TFLongformerOutput(config, name="output")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
(
|
||||
@ -898,12 +1154,14 @@ class TFLongformerLayer(tf.keras.layers.Layer):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.layer = [
|
||||
@ -926,6 +1184,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
@ -954,6 +1213,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
@ -985,10 +1245,9 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
self.return_dict = config.use_return_dict
|
||||
self.pad_token_id = config.pad_token_id
|
||||
self.attention_window = config.attention_window
|
||||
|
||||
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
|
||||
self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFLongformerEncoder(config, name="encoder")
|
||||
self.pooler = TFBertPooler(config, name="pooler")
|
||||
self.pooler = TFLongformerPooler(config, name="pooler")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
@ -1084,6 +1343,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
is_index_masked = tf.math.less(attention_mask, 1)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, to_seq_length, 1, 1]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
@ -1097,7 +1357,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -1111,7 +1370,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@ -1149,22 +1407,27 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
|
||||
|
||||
input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
|
||||
batch_size, seq_len = input_shape[:2]
|
||||
|
||||
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
||||
|
||||
if padding_len > 0:
|
||||
logger.info(
|
||||
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
||||
seq_len, seq_len + padding_len, attention_window
|
||||
)
|
||||
)
|
||||
|
||||
paddings = tf.constant([[0, 0], [0, padding_len]])
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
|
||||
|
||||
if position_ids is not None:
|
||||
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
|
||||
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
|
||||
|
||||
if inputs_embeds is not None:
|
||||
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
|
||||
inputs_embeds_padding = self.embeddings(input_ids_padding)
|
||||
@ -1195,6 +1458,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
# simply use `global_attention_mask` as `attention_mask`
|
||||
# if no `attention_mask` is given
|
||||
attention_mask = global_attention_mask + 1
|
||||
|
||||
return attention_mask
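
# Editor's note: a minimal sketch (not part of this commit) of how the merged Longformer
# attention mask encodes token states: 0 = padding, 1 = local attention, 2 = global attention,
# which matches the `is_index_masked` / `is_index_global_attn` tests earlier in this file.
# Multiplying an existing attention mask with `global_attention_mask + 1` is an assumption
# based on the surrounding code; the values below are made up.
import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 1, 0]])         # 1 = real token, 0 = padding
global_attention_mask = tf.constant([[1, 0, 0, 0, 0]])  # 1 = token attends globally
merged = attention_mask * (global_attention_mask + 1)   # [[2, 1, 1, 1, 0]]

is_index_masked = tf.math.less(merged, 1)          # padding positions
is_index_global_attn = tf.math.greater(merged, 1)  # global-attention positions
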
@ -1339,11 +1603,13 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
|
||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.longformer(inputs, **kwargs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@ -1356,7 +1622,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.lm_head = TFRobertaLMHead(config, self.longformer.embeddings, name="lm_head")
|
||||
self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
@ -1390,8 +1656,10 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
@ -1409,14 +1677,13 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output, training=training)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFMaskedLMOutput(
|
||||
@ -1435,8 +1702,8 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels,
|
||||
@ -1477,6 +1744,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
global_attention_mask = inputs[2]
|
||||
@ -1520,15 +1788,13 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
@ -1536,6 +1802,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFQuestionAnsweringModelOutput(
|
||||
|
@ -31,7 +31,6 @@ from .file_utils import (
|
||||
add_start_docstrings_to_callable,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_tf_bert import TFBertIntermediate
|
||||
from .modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
@ -63,11 +62,29 @@ _CONFIG_FOR_DOC = "MobileBertConfig"
|
||||
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
||||
|
||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"mobilebert-uncased",
|
||||
"google/mobilebert-uncased",
|
||||
# See all MobileBERT models at https://huggingface.co/models?filter=mobilebert
|
||||
]
|
||||
|
||||
|
||||
class TFMobileBertIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFLayerNorm(tf.keras.layers.LayerNormalization):
|
||||
def __init__(self, feat_size, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -353,12 +370,6 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
|
||||
return outputs
|
||||
|
||||
|
||||
class TFMobileBertIntermediate(TFBertIntermediate):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.dense = tf.keras.layers.Dense(config.intermediate_size, name="dense")
|
||||
|
||||
|
||||
class TFOutputBottleneck(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
@ -26,8 +26,8 @@ from .file_utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_callable,
|
||||
)
|
||||
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer
|
||||
from .modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
TFMaskedLMOutput,
|
||||
TFMultipleChoiceModelOutput,
|
||||
@ -64,14 +64,48 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
class TFRobertaEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
"""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
self.padding_idx = 1
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.initializer_range = config.initializer_range
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.token_type_embeddings = tf.keras.layers.Embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
name="token_type_embeddings",
|
||||
)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape):
|
||||
"""Build shared word embedding layer """
|
||||
with tf.name_scope("word_embeddings"):
|
||||
# Create and initialize weights. The random normal initializer was chosen
|
||||
# arbitrarily, and works well.
|
||||
self.word_embeddings = self.add_weight(
|
||||
"weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def create_position_ids_from_input_ids(self, x):
|
||||
"""Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
@ -82,6 +116,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
"""
|
||||
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
|
||||
incremental_indicies = tf.math.cumsum(mask, axis=1) * mask
|
||||
|
||||
return incremental_indicies + self.padding_idx
|
||||
|
||||
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
@ -91,10 +126,40 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
:return tf.Tensor:
|
||||
"""
|
||||
seq_length = shape_list(inputs_embeds)[1]
|
||||
|
||||
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
return position_ids
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
mode: string, a valid value is one of "embedding" and "linear".
|
||||
Returns:
|
||||
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
|
||||
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
|
||||
linear tensor, float32 with shape [batch_size, length, vocab_size].
|
||||
Raises:
|
||||
ValueError: if mode is not valid.
|
||||
|
||||
Shared weights logic adapted from
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None)
@ -106,19 +171,429 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

return super()._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
if input_ids is not None:
input_shape = shape_list(input_ids)
else:
input_shape = shape_list(inputs_embeds)[:-1]

seq_length = input_shape[1]

if position_ids is None:
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]

if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)

if inputs_embeds is None:
inputs_embeds = tf.gather(self.word_embeddings, input_ids)

position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)

return embeddings

def _linear(self, inputs):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
batch_size = shape_list(inputs)[0]
length = shape_list(inputs)[1]
x = tf.reshape(inputs, [-1, self.hidden_size])
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

return tf.reshape(logits, [batch_size, length, self.vocab_size])


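# Illustrative example (editor's sketch, not part of this commit): the "linear" mode above
# reuses the word-embedding matrix as the output projection (weight tying). Shapes only;
# the sizes below are placeholders.
import tensorflow as tf

batch_size, length, hidden_size, vocab_size = 2, 5, 8, 100
word_embeddings = tf.random.normal([vocab_size, hidden_size])  # same table used for the lookup
hidden = tf.random.normal([batch_size, length, hidden_size])

x = tf.reshape(hidden, [-1, hidden_size])
logits = tf.matmul(x, word_embeddings, transpose_b=True)  # (batch_size * length, vocab_size)
logits = tf.reshape(logits, [batch_size, length, vocab_size])
print(logits.shape)  # (2, 5, 100)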
# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFRobertaPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)

def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)

return pooled_output


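# Illustrative example (editor's sketch, not part of this commit): "pooling" here just takes
# the hidden state of the first token (the <s> / [CLS] position) and passes it through a
# tanh-activated Dense layer. Placeholder sizes.
import tensorflow as tf

hidden_states = tf.random.normal([2, 6, 8])  # (batch_size, seq_len, hidden_size)
first_token_tensor = hidden_states[:, 0]  # (batch_size, hidden_size)
pooled_output = tf.keras.layers.Dense(8, activation="tanh")(first_token_tensor)
print(pooled_output.shape)  # (2, 8)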
# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
class TFRobertaSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)

self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))

return tf.transpose(x, perm=[0, 2, 1, 3])

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(
query_layer, key_layer, transpose_b=True
) # (batch size, num_heads, seq_len_q, seq_len_k)
dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype) # scale attention_scores
attention_scores = attention_scores / tf.math.sqrt(dk)

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(
context_layer, (batch_size, -1, self.all_head_size)
) # (batch_size, seq_len_q, all_head_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

return outputs


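# Illustrative example (editor's sketch, not part of this commit): the core of the class
# above is plain scaled dot-product attention. Minimal standalone version, assuming the
# inputs are already split into heads of size `head_dim` and the mask is additive
# (0 for visible positions, -10000 for masked ones).
import tensorflow as tf

def scaled_dot_product_attention(q, k, v, additive_mask=None):
    # q, k, v: (batch, heads, seq, head_dim)
    dk = tf.cast(tf.shape(k)[-1], q.dtype)
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)
    if additive_mask is not None:
        scores = scores + additive_mask
    probs = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(probs, v), probs

q = k = v = tf.random.normal([1, 2, 4, 8])
context, probs = scaled_dot_product_attention(q, k, v)
print(context.shape, probs.shape)  # (1, 2, 4, 8) (1, 2, 4, 4)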
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFRobertaSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)

return hidden_states


# Copied from transformers.modeling_tf_bert.TFBertAttention with Bert->Roberta
class TFRobertaAttention(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.self_attention = TFRobertaSelfAttention(config, name="self")
self.dense_output = TFRobertaSelfOutput(config, name="output")

def prune_heads(self, heads):
raise NotImplementedError

def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
self_outputs = self.self_attention(
input_tensor, attention_mask, head_mask, output_attentions, training=training
)
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them

return outputs


# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFRobertaIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)

if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act

def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)

return hidden_states


# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFRobertaOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)

return hidden_states


# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Roberta
class TFRobertaLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.attention = TFRobertaAttention(config, name="attention")
self.intermediate = TFRobertaIntermediate(config, name="intermediate")
self.bert_output = TFRobertaOutput(config, name="output")

def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions, training=training
)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them

return outputs


# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Roberta
class TFRobertaEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.layer = [TFRobertaLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

def call(
self,
hidden_states,
attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=False,
):
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
)
hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)


@keras_serializable
class TFRobertaMainLayer(TFBertMainLayer):
"""
Same as TFBertMainLayer but uses TFRobertaEmbeddings.
"""
class TFRobertaMainLayer(tf.keras.layers.Layer):
config_class = RobertaConfig

def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
super().__init__(**kwargs)

self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.encoder = TFRobertaEncoder(config, name="encoder")
self.pooler = TFRobertaPooler(config, name="pooler")
# The embeddings must be the last declaration in order to follow the weights order
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")

# Copied from transformers.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self):
return self.embeddings

# Copied from transformers.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]

# Copied from transformers.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
raise NotImplementedError

# Copied from transformers.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs

output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)

if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)

embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)

encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
training=training,
)

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)

if not return_dict:
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]

return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


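# Illustrative example (editor's sketch, not part of this commit): the additive attention-mask
# trick used in call() above. A 2D padding mask becomes a broadcastable 4D tensor that is 0 for
# visible positions and -10000 for masked ones, so the softmax over the raw scores gives masked
# positions ~0 weight.
import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.float32)  # 1 = attend, 0 = padding
extended = attention_mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0  # [[[[0., 0., 0., -10000.]]]]
scores = tf.zeros([1, 1, 4, 4]) + extended  # stands in for the raw attention scores
print(tf.nn.softmax(scores, axis=-1)[0, 0, 0].numpy())  # last position gets ~0 probability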
class TFRobertaPreTrainedModel(TFPreTrainedModel):
"""An abstract class to handle weights initialization and
@ -246,6 +721,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):

def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)

self.vocab_size = config.vocab_size
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
@ -259,6 +735,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):

def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")

super().build(input_shape)

def call(self, features):
@ -17,7 +17,6 @@


import itertools
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@ -114,13 +113,12 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
return mask, attn_mask


class TFMultiHeadAttention(tf.keras.layers.Layer):

class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
NEW_ID = itertools.count()

def __init__(self, n_heads, dim, config, **kwargs):
super().__init__(**kwargs)
self.layer_id = next(TFMultiHeadAttention.NEW_ID)
self.layer_id = next(TFXLMMultiHeadAttention.NEW_ID)
self.dim = dim
self.n_heads = n_heads
self.output_attentions = config.output_attentions
@ -143,13 +141,15 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
bs, qlen, dim = shape_list(input)

if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen
else:
klen = shape_list(kv)[1]

# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads = self.n_heads
dim_per_head = self.dim // n_heads
dim_per_head = tf.math.divide(self.dim, self.n_heads)
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)

def shape(x):
@ -161,6 +161,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)

if kv is None:
k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head)
@ -177,14 +178,17 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]

cache[self.layer_id] = (k, v)

q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
q = tf.cast(q, dtype=tf.float32)
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
k = tf.cast(k, dtype=q.dtype)
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
mask = tf.cast(mask, dtype=scores.dtype)
scores = scores - 1e30 * (1.0 - mask)

weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)

@ -194,16 +198,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):

context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)

outputs = (self.out_lin(context),)

if output_attentions:
outputs = outputs + (weights,)

return outputs


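# Illustrative example (editor's sketch, not part of this commit): the rewritten scaling above
# replaces the Python-level `q / math.sqrt(dim_per_head)` with graph ops; numerically the two
# are the same, since dividing by sqrt(d) equals multiplying by rsqrt(d).
import math

import tensorflow as tf

dim_per_head = 64
q = tf.random.normal([2, 8, 5, dim_per_head])
a = q / math.sqrt(dim_per_head)
b = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
print(bool(tf.reduce_all(tf.abs(a - b) < 1e-6)))  # True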
class TFTransformerFFN(tf.keras.layers.Layer):
class TFXLMTransformerFFN(tf.keras.layers.Layer):
def __init__(self, in_dim, dim_hidden, out_dim, config, **kwargs):
super().__init__(**kwargs)

self.lin1 = tf.keras.layers.Dense(dim_hidden, kernel_initializer=get_initializer(config.init_std), name="lin1")
self.lin2 = tf.keras.layers.Dense(out_dim, kernel_initializer=get_initializer(config.init_std), name="lin2")
self.act = get_tf_activation("gelu") if config.gelu_activation else get_tf_activation("relu")
@ -214,6 +220,7 @@ class TFTransformerFFN(tf.keras.layers.Layer):
x = self.act(x)
x = self.lin2(x)
x = self.dropout(x, training=training)

return x


@ -223,6 +230,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

def __init__(self, config, **kwargs):
super().__init__(**kwargs)

self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.return_dict = config.use_return_dict
@ -230,8 +238,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# encoder / decoder, output layer
self.is_encoder = config.is_encoder
self.is_decoder = not config.is_encoder

if self.is_decoder:
raise NotImplementedError("Currently XLM can only be used as an encoder")

# self.with_output = with_output
self.causal = config.causal

@ -257,16 +267,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# embeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)

self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
self.dim,
embeddings_initializer=get_initializer(config.embed_init_std),
name="position_embeddings",
)

if config.sinusoidal_embeddings:
raise NotImplementedError
# create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)

if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = tf.keras.layers.Embedding(
self.n_langs,
@ -274,6 +285,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
embeddings_initializer=get_initializer(config.embed_init_std),
name="lang_embeddings",
)

self.embeddings = TFSharedEmbeddings(
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
) # padding_idx=self.pad_index)
@ -290,7 +302,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

for i in range(self.n_layers):
self.attentions.append(
TFMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
TFXLMMultiHeadAttention(self.n_heads, self.dim, config=config, name="attentions_._{}".format(i))
)
self.layer_norm1.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1_._{}".format(i))
@ -299,7 +311,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self.ffns.append(
TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
TFXLMTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name="ffns_._{}".format(i))
)
self.layer_norm2.append(
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
@ -308,6 +320,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}

for layer, heads in pruned_heads:
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
@ -398,7 +411,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

# check inputs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@ -416,13 +431,17 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)

# langs
if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(langs), [bs, slen])
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)

# Prepare head mask if needed
@ -455,6 +474,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)

tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = tensor * mask[..., tf.newaxis]
@ -462,6 +482,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# transformer layers
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None

for i in range(self.n_layers):
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
@ -471,8 +492,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
)
attn = attn_outputs[0]

if output_attentions:
attentions = attentions + (attn_outputs[1],)

attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
@ -502,6 +525,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

if not return_dict:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)

return TFBaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)


@ -691,9 +715,11 @@ class TFXLMPredLayer(tf.keras.layers.Layer):

def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)

self.asm = config.asm
self.n_words = config.n_words
self.pad_index = config.pad_index

if config.asm is False:
self.input_embeddings = input_embeddings
else:
@ -709,11 +735,13 @@ class TFXLMPredLayer(tf.keras.layers.Layer):
def build(self, input_shape):
# The output weights are the same as the input embeddings, but there is an output-only bias for each token.
self.bias = self.add_weight(shape=(self.n_words,), initializer="zeros", trainable=True, name="bias")

super().build(input_shape)

def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias

return hidden_states

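# Illustrative example (editor's sketch, not part of this commit): TFXLMPredLayer above projects
# hidden states back onto the vocabulary by reusing the shared input embeddings in "linear" mode
# and adding a learned bias. Same computation with placeholder sizes:
import tensorflow as tf

n_words, dim = 100, 16
embedding_table = tf.random.normal([n_words, dim])  # shared with the input embeddings
bias = tf.zeros([n_words])
hidden_states = tf.random.normal([2, 7, dim])

x = tf.reshape(hidden_states, [-1, dim])  # (batch * seq, dim)
logits = tf.matmul(x, embedding_table, transpose_b=True) + bias  # (batch * seq, n_words)
logits = tf.reshape(logits, [2, 7, n_words])
print(logits.shape)  # (2, 7, 100)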