mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491)
* Add unpack_inputs to remaining models * remove stray use of inputs in the templates; fix tf.debugging of attn masks
This commit is contained in:
parent
ae189ef991
commit
c2f8eaf6bc
@ -943,12 +943,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask in [head_mask, cross_attn_head_mask]:
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
|
@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceSummary,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list
|
||||
from ...utils import (
|
||||
@ -568,6 +568,7 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
return head_mask
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -582,60 +583,36 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
extended_attention_mask = self.get_extended_attention_mask(
|
||||
inputs["attention_mask"], input_shape, hidden_states.dtype
|
||||
)
|
||||
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
|
||||
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
|
||||
if hasattr(self, "embeddings_project"):
|
||||
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.embeddings_project(hidden_states, training=training)
|
||||
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
extended_attention_mask,
|
||||
inputs["head_mask"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
head_mask,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@ -754,6 +731,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
|
||||
|
||||
self.convbert = TFConvBertMainLayer(config, name="convbert")
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -775,9 +753,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
outputs = self.convbert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -788,19 +764,6 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.convbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
@ -886,6 +849,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.generator_lm_head.name
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -914,9 +878,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
generator_hidden_states = self.convbert(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -926,28 +888,14 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingL
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
generator_hidden_states = self.convbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
generator_sequence_output = generator_hidden_states[0]
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + generator_hidden_states[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1010,6 +958,7 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
|
||||
self.convbert = TFConvBertMainLayer(config, name="convbert")
|
||||
self.classifier = TFConvBertClassificationHead(config, name="classifier")
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1038,10 +987,8 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
outputs = self.convbert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@ -1050,26 +997,12 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceC
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.convbert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.classifier(outputs[0], training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
|
||||
logits = self.classifier(outputs[0], training=training)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1117,6 +1050,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
|
||||
"""
|
||||
return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(
|
||||
CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||
)
|
||||
@ -1146,43 +1080,20 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
||||
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
outputs = self.convbert(
|
||||
@ -1190,19 +1101,19 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLos
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
inputs["head_mask"],
|
||||
head_mask,
|
||||
flat_inputs_embeds,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
logits = self.sequence_summary(outputs[0], training=inputs["training"])
|
||||
logits = self.sequence_summary(outputs[0], training=training)
|
||||
logits = self.classifier(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@ -1256,6 +1167,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1282,10 +1194,8 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
|
||||
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
outputs = self.convbert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@ -1294,28 +1204,14 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassif
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.convbert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
|
||||
loss = None if labels is None else self.hf_compute_loss(labels, logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@ -1350,6 +1246,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1383,10 +1280,8 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
outputs = self.convbert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@ -1395,22 +1290,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.convbert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
@ -1419,12 +1299,12 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnswer
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
loss = None
|
||||
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
@ -376,6 +375,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -392,53 +392,34 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["past_key_values"] is None:
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
inputs["past_key_values"] = [None] * len(self.h)
|
||||
past_key_values = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||
past_length = shape_list(past_key_values[0][0])[-2]
|
||||
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
|
||||
if position_ids is None:
|
||||
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
|
||||
|
||||
if inputs["attention_mask"] is not None:
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||
inputs["attention_mask"] = tf.reshape(
|
||||
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||
)
|
||||
attention_mask_shape = shape_list(attention_mask)
|
||||
attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@ -446,78 +427,74 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
one_cst = tf.constant(1.0)
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||
inputs["attention_mask"] = tf.multiply(
|
||||
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
|
||||
)
|
||||
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
|
||||
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if inputs["head_mask"] is not None:
|
||||
if head_mask is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids, mode="embedding")
|
||||
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||
)
|
||||
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.wte(token_type_ids, mode="embedding")
|
||||
else:
|
||||
token_type_embeds = tf.constant(0.0)
|
||||
|
||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||
hidden_states = inputs["inputs_embeds"] + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=inputs["training"])
|
||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
|
||||
hidden_states = inputs_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=training)
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
|
||||
presents = () if inputs["use_cache"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])):
|
||||
if inputs["output_hidden_states"]:
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
inputs["use_cache"],
|
||||
inputs["output_attentions"],
|
||||
training=inputs["training"],
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if inputs["use_cache"]:
|
||||
if use_cache:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],)
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
# Add last hidden state
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
if output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutputWithPast(
|
||||
|
@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list
|
||||
from ...utils import (
|
||||
@ -1654,6 +1654,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -1703,95 +1704,74 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
|
||||
# merge `global_attention_mask` and `attention_mask`
|
||||
if inputs["global_attention_mask"] is not None:
|
||||
inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
|
||||
(inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
|
||||
)
|
||||
if global_attention_mask is not None:
|
||||
attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype)
|
||||
|
||||
(
|
||||
padding_len,
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["inputs_embeds"],
|
||||
) = self._pad_to_window_size(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
(padding_len, input_ids, attention_mask, inputs_embeds,) = self._pad_to_window_size(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pad_token_id=self.padding_idx,
|
||||
)
|
||||
|
||||
input_shape = shape_list(inputs["attention_mask"])
|
||||
input_shape = shape_list(attention_mask)
|
||||
# is index masked or global attention
|
||||
is_index_masked = tf.math.less(tf.cast(inputs["attention_mask"], tf.int8), 1)
|
||||
is_index_global_attn = tf.math.greater(tf.cast(inputs["attention_mask"], tf.int8), 1)
|
||||
is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1)
|
||||
is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs["inputs_embeds"] + embed_pos
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
||||
# check attention mask and invert
|
||||
if inputs["attention_mask"] is not None:
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
|
||||
attention_mask = _expand_mask(attention_mask)[:, 0, 0, :]
|
||||
attention_mask = attention_mask[:, :, None, None]
|
||||
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = all_global_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
|
||||
)
|
||||
|
||||
# encoder layers
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
|
||||
encoder_states = encoder_states + (hidden_states_to_add,)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
if training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=head_mask[idx] if head_mask is not None else None,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
@ -1799,7 +1779,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
if output_attentions:
|
||||
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
|
||||
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||
|
||||
@ -1811,17 +1791,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
|
||||
|
||||
# undo padding
|
||||
if inputs["output_attentions"]:
|
||||
if output_attentions:
|
||||
all_attentions = (
|
||||
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
|
||||
if padding_len > 0
|
||||
else all_attentions
|
||||
)
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return TFLEDEncoderBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
@ -1915,6 +1895,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -1985,45 +1966,25 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
past_key_values_length = (
|
||||
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
|
||||
)
|
||||
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs["inputs_embeds"]
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
if input_shape[-1] > 1:
|
||||
@ -2033,17 +1994,15 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||
)
|
||||
if attention_mask is not None and input_shape[-1] > 1:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
|
||||
encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
|
||||
|
||||
hidden_states = self.layernorm_embedding(hidden_states + positions)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = ()
|
||||
@ -2052,54 +2011,52 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop):
|
||||
if training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
|
||||
if inputs["encoder_head_mask"] is not None
|
||||
else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=head_mask[idx] if head_mask is not None else None,
|
||||
encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
if use_cache:
|
||||
present_key_values += (present_key_value,)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
all_cross_attentions += (layer_cross_attn,)
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
all_self_attns = all_self_attns if inputs["output_attentions"] else None
|
||||
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
|
||||
all_self_attns = all_self_attns if output_attentions else None
|
||||
all_cross_attentions = all_cross_attentions if output_attentions else None
|
||||
|
||||
present_key_values = present_key_values if inputs["use_cache"] else None
|
||||
present_key_values = present_key_values if use_cache else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||
@ -2149,6 +2106,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@ -2169,72 +2127,51 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
global_attention_mask=global_attention_mask,
|
||||
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
use_cache = False
|
||||
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
|
||||
elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput):
|
||||
encoder_outputs = TFLEDEncoderBaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
# If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
|
||||
elif not return_dict and not isinstance(encoder_outputs, tuple):
|
||||
encoder_outputs = encoder_outputs.to_tuple()
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
|
||||
inputs["use_cache"] = False
|
||||
|
||||
if inputs["encoder_outputs"] is None:
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
|
||||
elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFLEDEncoderBaseModelOutput):
|
||||
inputs["encoder_outputs"] = TFLEDEncoderBaseModelOutput(
|
||||
last_hidden_state=inputs["encoder_outputs"][0],
|
||||
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
|
||||
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
|
||||
)
|
||||
# If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
|
||||
elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
inputs["decoder_input_ids"],
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
encoder_head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return decoder_outputs + inputs["encoder_outputs"]
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return TFLEDSeq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
@ -2242,10 +2179,10 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
encoder_global_attentions=inputs["encoder_outputs"].global_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
encoder_global_attentions=encoder_outputs.global_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -2265,6 +2202,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
def get_decoder(self):
|
||||
return self.led.decoder
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -2292,17 +2230,16 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
|
||||
outputs = self.led(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -2311,25 +2248,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.led(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
@ -2393,6 +2311,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
def set_output_embeddings(self, value):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
@ -2435,17 +2354,22 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
>>> # probs[5] is associated with the mask token
|
||||
```"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
outputs = self.led(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -2453,41 +2377,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if inputs["labels"] is not None:
|
||||
inputs["use_cache"] = False
|
||||
if inputs["decoder_input_ids"] is None:
|
||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
outputs = self.led(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
lm_logits = self.led.shared(outputs[0], mode="linear")
|
||||
lm_logits = lm_logits + self.final_logits_bias
|
||||
masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
|
||||
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
return TFLEDSeq2SeqLMOutput(
|
||||
|
@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask in [head_mask, cross_attn_head_mask]:
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
|
@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
||||
for attn_mask in [head_mask, cross_attn_head_mask]:
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user