TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491)

* Add unpack_inputs to remaining models

* Remove stray use of `inputs` in the templates; fix `tf.debugging` of attention masks
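The whole commit is mechanical once the decorator exists: `@unpack_inputs` resolves every argument by name before the method body runs, so the body can write `input_ids` directly instead of `inputs["input_ids"]`. Below is a minimal sketch of that decorator pattern; it is not the actual `modeling_tf_utils.unpack_inputs` (which also resolves config-backed defaults and tensors packed into the first positional argument), and the `Toy` class is hypothetical.

```python
import functools
import inspect

def unpack_inputs(func):
    """Bind positional/keyword arguments onto the wrapped call's signature so
    the body can use plain named arguments instead of an `inputs` dict."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Bind whatever the caller passed against the declared signature,
        # fill in defaults, and re-dispatch as pure keyword arguments.
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = dict(bound.arguments)
        arguments.pop("self")
        extra = arguments.pop("kwargs", {})  # swallow stray **kwargs
        return func(self, **arguments, **extra)
    return wrapper

class Toy:
    @unpack_inputs
    def call(self, input_ids=None, attention_mask=None, training=False, **kwargs):
        return input_ids, attention_mask, training

print(Toy().call([1, 2, 3], attention_mask=[1, 1, 1]))  # ([1, 2, 3], [1, 1, 1], False)
```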
Joao Gante 2022-03-30 17:12:27 +01:00 committed by GitHub
parent ae189ef991
commit c2f8eaf6bc
7 changed files with 463 additions and 930 deletions


@@ -943,12 +943,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )
         for idx, decoder_layer in enumerate(self.layers):
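Two things happen in this hunk. First, the assertion only runs under `tf.executing_eagerly()`, because data-dependent `tf.debugging` asserts cannot be lowered into an XLA graph. Second, the old message interpolated `{attn_mask}` itself, which dumps a tensor repr into the error text; the fix interpolates the mask's name instead. A standalone sketch of the eager-only check (the `check_layer_mask` helper is made up for illustration):

```python
import tensorflow as tf

def check_layer_mask(mask, num_layers, name):
    # Data-dependent asserts can't be compiled by XLA, so only check when
    # running eagerly; inside a traced/compiled graph this is a no-op.
    if mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(mask)[0],
            num_layers,
            message=f"The {name} should be specified for {num_layers} layers.",
        )

check_layer_mask(tf.ones((12,)), 12, "head_mask")  # passes silently in eager mode
```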


@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     TFSequenceSummary,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import (

@@ -568,6 +568,7 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
         return head_mask

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -582,60 +583,36 @@ class TFConvBertMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)

-        hidden_states = self.embeddings(
-            inputs["input_ids"],
-            inputs["position_ids"],
-            inputs["token_type_ids"],
-            inputs["inputs_embeds"],
-            training=inputs["training"],
-        )
-        extended_attention_mask = self.get_extended_attention_mask(
-            inputs["attention_mask"], input_shape, hidden_states.dtype
-        )
-        inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
+        hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
+        head_mask = self.get_head_mask(head_mask)

         if hasattr(self, "embeddings_project"):
-            hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
+            hidden_states = self.embeddings_project(hidden_states, training=training)

         hidden_states = self.encoder(
             hidden_states,
             extended_attention_mask,
-            inputs["head_mask"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            inputs["return_dict"],
-            training=inputs["training"],
+            head_mask,
+            output_attentions,
+            output_hidden_states,
+            return_dict,
+            training=training,
         )

         return hidden_states
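`get_extended_attention_mask` is what turns the user-facing 0/1 padding mask into something attention scores can consume. Conventionally in the TF models it broadcasts `[batch, seq_len]` to `[batch, 1, 1, seq_len]` and maps 1 to a 0.0 bias and 0 to a large negative bias; the helper below is a self-contained sketch of that convention, not the library function itself:

```python
import tensorflow as tf

def extended_attention_mask(attention_mask, dtype=tf.float32):
    # [batch, seq_len] -> [batch, 1, 1, seq_len] so the mask broadcasts over
    # (num_heads, query_len); keep=1 -> 0.0 bias, pad=0 -> -10000.0 bias.
    mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], dtype)
    return (1.0 - mask) * -10000.0

print(extended_attention_mask(tf.constant([[1, 1, 1, 0]])))
# bias row: [0., 0., 0., -10000.]
```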
@@ -754,6 +731,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
         self.convbert = TFConvBertMainLayer(config, name="convbert")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -775,9 +753,7 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.convbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -788,19 +764,6 @@ class TFConvBertModel(TFConvBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -886,6 +849,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingLoss):
     def get_prefix_bias_name(self):
         return self.name + "/" + self.generator_lm_head.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -914,9 +878,7 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        generator_hidden_states = self.convbert(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -926,28 +888,14 @@ class TFConvBertForMaskedLM(TFConvBertPreTrainedModel, TFMaskedLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        generator_hidden_states = self.convbert(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         generator_sequence_output = generator_hidden_states[0]
-        prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
-        prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
+        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
+        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores,) + generator_hidden_states[1:]

             return ((loss,) + output) if loss is not None else output
@@ -1010,6 +958,7 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceClassificationLoss):
         self.convbert = TFConvBertMainLayer(config, name="convbert")
         self.classifier = TFConvBertClassificationHead(config, name="classifier")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1038,10 +987,8 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceClassificationLoss):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1050,26 +997,12 @@ class TFConvBertForSequenceClassification(TFConvBertPreTrainedModel, TFSequenceClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
-        logits = self.classifier(outputs[0], training=inputs["training"])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        logits = self.classifier(outputs[0], training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[1:]

             return ((loss,) + output) if loss is not None else output
@@ -1117,6 +1050,7 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLoss):
         """
         return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(
         CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
     )
@@ -1146,43 +1080,20 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLoss):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_position_ids = (
-            tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
         outputs = self.convbert(
@@ -1190,19 +1101,19 @@ class TFConvBertForMultipleChoice(TFConvBertPreTrainedModel, TFMultipleChoiceLoss):
             flat_attention_mask,
             flat_token_type_ids,
             flat_position_ids,
-            inputs["head_mask"],
+            head_mask,
             flat_inputs_embeds,
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
-        logits = self.sequence_summary(outputs[0], training=inputs["training"])
+        logits = self.sequence_summary(outputs[0], training=training)
         logits = self.classifier(logits)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + outputs[1:]

             return ((loss,) + output) if loss is not None else output
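The multiple-choice head works by folding the choice dimension into the batch so the shared encoder sees ordinary 2-D inputs, then unfolding the per-choice scores at the end. A toy shape walk-through (stand-in zeros instead of a real encoder and classifier):

```python
import tensorflow as tf

batch, num_choices, seq_len = 2, 4, 8
input_ids = tf.zeros((batch, num_choices, seq_len), dtype=tf.int32)

# Fold choices into the batch: (2, 4, 8) -> (8, 8), run the shared encoder,
# score each row, then unfold back to one logit per choice.
flat_input_ids = tf.reshape(input_ids, (-1, seq_len))
logits = tf.zeros((batch * num_choices, 1))           # stand-in classifier output
reshaped_logits = tf.reshape(logits, (-1, num_choices))
print(flat_input_ids.shape, reshaped_logits.shape)    # (8, 8) (2, 4)
```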
@@ -1256,6 +1167,7 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassificationLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1282,10 +1194,8 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassificationLoss):
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1294,28 +1204,14 @@ class TFConvBertForTokenClassification(TFConvBertPreTrainedModel, TFTokenClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=inputs["training"])
+        sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[1:]

             return ((loss,) + output) if loss is not None else output
@@ -1350,6 +1246,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnsweringLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1383,10 +1280,8 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnsweringLoss):
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        outputs = self.convbert(
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -1395,22 +1290,7 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnsweringLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.convbert(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]

         logits = self.qa_outputs(sequence_output)
@@ -1419,12 +1299,12 @@ class TFConvBertForQuestionAnswering(TFConvBertPreTrainedModel, TFQuestionAnsweringLoss):
         end_logits = tf.squeeze(end_logits, axis=-1)
         loss = None

-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + outputs[1:]

             return ((loss,) + output) if loss is not None else output


@@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFSharedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
     unpack_inputs,
 )

@@ -376,6 +375,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -392,53 +392,34 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["past_key_values"] is None:
+        if past_key_values is None:
             past_length = 0
-            inputs["past_key_values"] = [None] * len(self.h)
+            past_key_values = [None] * len(self.h)
         else:
-            past_length = shape_list(inputs["past_key_values"][0][0])[-2]
+            past_length = shape_list(past_key_values[0][0])[-2]

-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask_shape = shape_list(inputs["attention_mask"])
-            inputs["attention_mask"] = tf.reshape(
-                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
-            )
+            attention_mask_shape = shape_list(attention_mask)
+            attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -446,78 +427,74 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             one_cst = tf.constant(1.0)
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-            inputs["attention_mask"] = tf.multiply(
-                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
-            )
+            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.num_hidden_layers
+            head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)

-        inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
+        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids, mode="embedding")

-        if inputs["token_type_ids"] is not None:
-            inputs["token_type_ids"] = tf.reshape(
-                inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
-            )
-            token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
+        if token_type_ids is not None:
+            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+            token_type_embeds = self.wte(token_type_ids, mode="embedding")
         else:
             token_type_embeds = tf.constant(0.0)

-        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
-        hidden_states = inputs["inputs_embeds"] + token_type_embeds
-        hidden_states = self.drop(hidden_states, training=inputs["training"])
+        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+        hidden_states = inputs_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states, training=training)

         output_shape = input_shape + [shape_list(hidden_states)[-1]]

-        presents = () if inputs["use_cache"] else None
-        all_attentions = () if inputs["output_attentions"] else None
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])):
-            if inputs["output_hidden_states"]:
+        presents = () if use_cache else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

             outputs = block(
                 hidden_states,
                 layer_past,
-                inputs["attention_mask"],
-                inputs["head_mask"][i],
-                inputs["use_cache"],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                attention_mask,
+                head_mask[i],
+                use_cache,
+                output_attentions,
+                training=training,
             )

             hidden_states = outputs[0]
-            if inputs["use_cache"]:
+            if use_cache:
                 presents = presents + (outputs[1],)

-            if inputs["output_attentions"]:
-                all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],)
+            if output_attentions:
+                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

         hidden_states = self.ln_f(hidden_states)

         hidden_states = tf.reshape(hidden_states, output_shape)
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if inputs["output_attentions"]:
+        if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

         return TFBaseModelOutputWithPast(
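One detail worth noting in this hunk: when a cache (`past_key_values`) is present, only the new tokens are fed through the network, so their `position_ids` must start where the cache ends rather than restart at zero. A small worked example of that offset:

```python
import tensorflow as tf

# With a cache, only the unseen tokens are fed in, so their positions must
# continue after the cached ones.
past_length = 4   # e.g. shape_list(past_key_values[0][0])[-2]
new_len = 3       # input_shape[-1] for the current step
position_ids = tf.expand_dims(tf.range(past_length, new_len + past_length), axis=0)
print(position_ids)  # [[4 5 6]]
```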


@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import (

@@ -1654,6 +1654,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1703,95 +1704,74 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            global_attention_mask=global_attention_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

         # merge `global_attention_mask` and `attention_mask`
-        if inputs["global_attention_mask"] is not None:
-            inputs["attention_mask"] = inputs["attention_mask"] * tf.cast(
-                (inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype
-            )
+        if global_attention_mask is not None:
+            attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype)

-        (
-            padding_len,
-            inputs["input_ids"],
-            inputs["attention_mask"],
-            inputs["inputs_embeds"],
-        ) = self._pad_to_window_size(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
+        (padding_len, input_ids, attention_mask, inputs_embeds,) = self._pad_to_window_size(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             pad_token_id=self.padding_idx,
         )

-        input_shape = shape_list(inputs["attention_mask"])
+        input_shape = shape_list(attention_mask)

         # is index masked or global attention
-        is_index_masked = tf.math.less(tf.cast(inputs["attention_mask"], tf.int8), 1)
-        is_index_global_attn = tf.math.greater(tf.cast(inputs["attention_mask"], tf.int8), 1)
+        is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1)
+        is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1)
         is_global_attn = tf.math.reduce_any(is_index_global_attn)

         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
+        hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
-            inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
+            attention_mask = _expand_mask(attention_mask)[:, 0, 0, :]
+            attention_mask = attention_mask[:, :, None, None]

-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = all_global_attentions = () if output_attentions else None

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):

-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
                 encoder_states = encoder_states + (hidden_states_to_add,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

             layer_outputs = encoder_layer(
                 hidden_states=hidden_states,
-                attention_mask=inputs["attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                attention_mask=attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
                 is_index_masked=is_index_masked,
                 is_index_global_attn=is_index_global_attn,
                 is_global_attn=is_global_attn,
@@ -1799,7 +1779,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):

             hidden_states = layer_outputs[0]

-            if inputs["output_attentions"]:
+            if output_attentions:
                 # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                 all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
@@ -1811,17 +1791,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         hidden_states = self.compute_hidden_states(hidden_states, padding_len)

         # undo padding
-        if inputs["output_attentions"]:
+        if output_attentions:
             all_attentions = (
                 tuple([state[:, :, :-padding_len, :] for state in all_attentions])
                 if padding_len > 0
                 else all_attentions
             )

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFLEDEncoderBaseModelOutput(
             last_hidden_state=hidden_states,
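The `dropout_probability < self.layerdrop` check in this hunk implements LayerDrop: during training, each encoder layer is skipped for the whole batch with probability `layerdrop`; at inference every layer runs. A toy version of the loop (beware that Python's `random` inside a traced `tf.function` would be evaluated once at trace time):

```python
import random
import tensorflow as tf

layerdrop = 0.1
layers = [tf.keras.layers.Dense(4) for _ in range(6)]

def encode(hidden_states, training=False):
    for layer in layers:
        # LayerDrop: with probability `layerdrop`, skip the layer entirely
        # for this forward pass; disabled at inference.
        if training and random.uniform(0, 1) < layerdrop:
            continue
        hidden_states = layer(hidden_states)
    return hidden_states

out = encode(tf.zeros((2, 4)), training=True)
```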
@@ -1915,6 +1895,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1985,45 +1966,25 @@ class TFLEDDecoder(tf.keras.layers.Layer):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            encoder_head_mask=encoder_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)

-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -2033,17 +1994,15 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is not None and input_shape[-1] > 1:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None and input_shape[-1] > 1:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)

         # decoder layers
         all_hidden_states = ()
@@ -2052,54 +2011,52 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         present_key_values = ()

         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )

         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)

             dropout_probability = random.uniform(0, 1)

-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue

-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
-                if inputs["encoder_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
                 past_key_value=past_key_value,
             )

-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)

-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
                 all_cross_attentions += (layer_cross_attn,)

-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)
         else:
             all_hidden_states = None

-        all_self_attns = all_self_attns if inputs["output_attentions"] else None
-        all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
+        all_self_attns = all_self_attns if output_attentions else None
+        all_cross_attentions = all_cross_attentions if output_attentions else None

-        present_key_values = present_key_values if inputs["use_cache"] else None
+        present_key_values = present_key_values if use_cache else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(
                 v
                 for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
@@ -2149,6 +2106,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -2169,72 +2127,51 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            encoder_outputs=encoder_outputs,
-            global_attention_mask=global_attention_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )

-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False

-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                global_attention_mask=inputs["global_attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                global_attention_mask=global_attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
             )
         # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFLEDEncoderBaseModelOutput):
-            inputs["encoder_outputs"] = TFLEDEncoderBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
+        elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput):
+            encoder_outputs = TFLEDEncoderBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )
         # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()

         decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            encoder_head_mask=inputs["head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
        )

-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs

         return TFLEDSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -2242,10 +2179,10 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
-            encoder_global_attentions=inputs["encoder_outputs"].global_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+            encoder_global_attentions=encoder_outputs.global_attentions,
         )
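The tuple-vs-`ModelOutput` juggling around `encoder_outputs` exists because callers may pass either form; the main layer normalizes to whichever matches `return_dict`. A self-contained sketch of that normalization, with a hypothetical `EncoderOutput` standing in for `TFLEDEncoderBaseModelOutput`:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class EncoderOutput:
    # Hypothetical stand-in for TFLEDEncoderBaseModelOutput.
    last_hidden_state: object
    hidden_states: Optional[Tuple] = None
    attentions: Optional[Tuple] = None

    def to_tuple(self):
        return tuple(v for v in (self.last_hidden_state, self.hidden_states, self.attentions) if v is not None)

def normalize(encoder_outputs, return_dict):
    # Tuple in, object wanted -> wrap; object in, tuple wanted -> unwrap.
    if return_dict and not isinstance(encoder_outputs, EncoderOutput):
        return EncoderOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )
    if not return_dict and not isinstance(encoder_outputs, tuple):
        return encoder_outputs.to_tuple()
    return encoder_outputs

print(normalize(("h",), return_dict=True))  # EncoderOutput(last_hidden_state='h', ...)
```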
@@ -2265,6 +2202,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
     def get_decoder(self):
         return self.led.decoder

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2292,17 +2230,16 @@ class TFLEDModel(TFLEDPreTrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.led(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2311,25 +2248,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.led(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            global_attention_mask=inputs["global_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         )
         return outputs
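Since the decorator keeps the public call signature intact, dict-style calls keep working unchanged. A quick smoke test along the lines of the docstring examples (assumes the allenai/led-base-16384 checkpoint can be downloaded):

from transformers import LEDTokenizer, TFLEDModel

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = TFLEDModel.from_pretrained("allenai/led-base-16384")

# Passing the whole tokenizer output as the first argument still works:
# the decorator unpacks it into input_ids, attention_mask, etc.
inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(inputs)
print(outputs.last_hidden_state.shape)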
@@ -2393,6 +2311,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -2435,17 +2354,22 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
         >>> # probs[5] is associated with the mask token
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+        if labels is not None:
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.led(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             global_attention_mask=global_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -2453,41 +2377,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["labels"] is not None:
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(
-                    inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.led(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            global_attention_mask=inputs["global_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         )
         lm_logits = self.led.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFLEDSeq2SeqLMOutput(
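The new labels branch follows the usual seq2seq convention: when labels are provided, caching is turned off and decoder_input_ids default to the labels shifted one position to the right, with decoder_start_token_id prepended. A rough sketch of what shift_tokens_right does (illustrative toy version; the library helper additionally asserts that the resulting ids are valid):

import tensorflow as tf

def toy_shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token and drop the last position, so the
    # decoder predicts token t from tokens strictly before t (teacher forcing).
    start_tokens = tf.ones_like(labels[:, :1]) * decoder_start_token_id
    shifted = tf.concat([start_tokens, labels[:, :-1]], axis=-1)
    # -100 marks loss-ignored positions in labels; map it back to a real id
    # before the ids are fed to the embedding layer.
    shifted = tf.where(tf.equal(shifted, -100), tf.ones_like(shifted) * pad_token_id, shifted)
    return shifted

labels = tf.constant([[42, 43, 44, -100]])
print(toy_shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
# tf.Tensor([[ 2 42 43 44]], shape=(1, 4), dtype=int32)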
View File: src/transformers/models/mbart/modeling_tf_mbart.py
@@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
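The loop change in this hunk (and the identical one in the Speech2Text file below) only affects the assertion message: interpolating the tensor itself into the f-string dumped a full tensor repr into the error, whereas the (name, tensor) pairs put the argument's name there. A standalone before/after illustration:

import tensorflow as tf

head_mask = tf.zeros((2, 16))

# Before: the tensor's repr leaks into the message.
print(f"The {head_mask} should be specified for 6 layers ...")

# After: iterate over (name, tensor) pairs so the argument name appears instead.
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
    print(f"The {attn_mask_name} should be specified for 6 layers ...")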
View File: src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
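On the tf.executing_eagerly() guard that all of these decoders keep: tf.debugging asserts are not XLA-compliant, so the shape check must only run in eager mode and be skipped when the model is traced or compiled. A minimal, self-contained illustration of the guard (check_num_layers is a made-up helper, not library code):

import tensorflow as tf

def check_num_layers(mask, num_layers):
    # Python-level condition: it is evaluated at trace time, so under
    # tf.function/XLA the branch is False and the assert is compiled away.
    if mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(mask)[0],
            num_layers,
            message=f"The mask should be specified for {num_layers} layers.",
        )

check_num_layers(tf.zeros((4, 16)), 4)  # eager: the assert executes and passes
tf.function(check_num_layers, jit_compile=True)(tf.zeros((4, 16)), 4)  # XLA: skipped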