Mirror of https://github.com/huggingface/transformers.git
add unpack_inputs decorator to mbart (#16097)
This commit is contained in:
parent: 3e9d0f7f59
commit: 9042dfe35c
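
The hunks below are from the TF MBart modeling code (src/transformers/models/mbart/modeling_tf_mbart.py). The change removes the hand-written `input_processing(...)` boilerplate from each `call` method and applies the `@unpack_inputs` decorator instead, so the method bodies read their arguments directly (`input_ids`, `return_dict`, ...) rather than through an `inputs` dict (`inputs["input_ids"]`, ...). Below is a minimal, self-contained sketch of the idea using only the standard library; it is NOT the real decorator in `transformers.modeling_tf_utils`, and `toy_unpack_inputs`, `ToyConfig`, and `ToyEncoder` are invented names for illustration.

# Illustrative sketch only -- not the real transformers implementation.
import functools
import inspect


def toy_unpack_inputs(func):
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Bind positional/keyword arguments and fill in the declared defaults.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        resolved = dict(bound.arguments)
        resolved.pop("self")
        # Fall back to config values for the usual output flags, which is roughly
        # what the deleted input_processing blocks did inside every method body.
        for name in ("output_attentions", "output_hidden_states", "return_dict"):
            if name in resolved and resolved[name] is None:
                resolved[name] = getattr(self.config, name)
        return func(self, **resolved)

    return wrapper


class ToyConfig:
    output_attentions = False
    output_hidden_states = False
    return_dict = True


class ToyEncoder:
    config = ToyConfig()

    @toy_unpack_inputs
    def call(self, input_ids=None, attention_mask=None, return_dict=None, training=False):
        # The body uses the plain parameter names directly; no `inputs` dict needed.
        return {"input_ids": input_ids, "return_dict": return_dict, "training": training}


print(ToyEncoder().call(input_ids=[5, 6, 7]))  # return_dict is resolved to True from ToyConfig

The real decorator (added to the imports in the first hunk below) wrapped the same `input_processing` logic at the time of this commit, which is why each `call` body in the diff shrinks to plain parameter names.
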
@@ -42,8 +42,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -666,6 +666,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -720,82 +721,69 @@ class TFMBartEncoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
 
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs["inputs_embeds"] + embed_pos
+        hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)
 
         # check attention mask and invert
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
+            attention_mask = _expand_mask(attention_mask)
         else:
             attention_mask = None
 
-        encoder_states = () if inputs["output_hidden_states"] else None
-        all_attentions = () if inputs["output_attentions"] else None
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
 
         # check if head_mask has a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        if inputs["head_mask"] is not None and tf.executing_eagerly():
+        if head_mask is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
-                shape_list(inputs["head_mask"])[0],
+                shape_list(head_mask)[0],
                 len(self.layers),
-                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
+                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
             )
 
         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):
 
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = random.uniform(0, 1)
-            if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
                 continue
 
             hidden_states, attn = encoder_layer(
                 hidden_states,
                 attention_mask,
-                inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
+                head_mask[idx] if head_mask is not None else None,
             )
 
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions += (attn,)
 
         hidden_states = self.layer_norm(hidden_states)
 
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return TFBaseModelOutput(
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
@@ -837,6 +825,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
     def set_embed_tokens(self, embed_tokens):
         self.embed_tokens = embed_tokens
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -920,45 +909,25 @@ class TFMBartDecoder(tf.keras.layers.Layer):
                 Whether or not to use the model in training mode (some modules like dropout modules have different
                 behaviors between training and evaluation).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
 
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        past_key_values_length = (
-            shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
-        )
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
 
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
 
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        hidden_states = inputs["inputs_embeds"]
+        hidden_states = inputs_embeds
 
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if input_shape[-1] > 1:
@@ -968,73 +937,69 @@ class TFMBartDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )
 
-        if inputs["attention_mask"] is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                inputs["attention_mask"], tgt_len=input_shape[-1]
-            )
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
 
-        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
 
         hidden_states = self.layernorm_embedding(hidden_states + positions)
-        hidden_states = self.dropout(hidden_states, training=inputs["training"])
+        hidden_states = self.dropout(hidden_states, training=training)
 
         # decoder layers
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        all_self_attns = () if inputs["output_attentions"] else None
-        all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
-        present_key_values = () if inputs["use_cache"] else None
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None
 
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in ["head_mask", "cross_attn_head_mask"]:
-            if inputs[attn_mask] is not None and tf.executing_eagerly():
+        for attn_mask in [head_mask, cross_attn_head_mask]:
+            if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
-                    shape_list(inputs[attn_mask])[0],
+                    shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
+                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )
 
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             dropout_probability = random.uniform(0, 1)
 
-            if inputs["training"] and (dropout_probability < self.layerdrop):
+            if training and (dropout_probability < self.layerdrop):
                 continue
 
-            past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=inputs["encoder_hidden_states"],
-                encoder_attention_mask=inputs["encoder_attention_mask"],
-                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
-                cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
-                if inputs["cross_attn_head_mask"] is not None
-                else None,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                 past_key_value=past_key_value,
             )
 
-            if inputs["use_cache"]:
+            if use_cache:
                 present_key_values += (present_key_value,)
 
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_self_attns += (layer_self_attn,)
 
-                if inputs["encoder_hidden_states"] is not None:
+                if encoder_hidden_states is not None:
                     all_cross_attns += (layer_cross_attn,)
 
         hidden_states = self.layer_norm(hidden_states)
 
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
             return TFBaseModelOutputWithPastAndCrossAttentions(
@@ -1081,6 +1046,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
         self.encoder.set_embed_tokens(embed_tokens)
         self.decoder.set_embed_tokens(embed_tokens)
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1101,80 +1067,57 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
+
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if decoder_input_ids is None and input_ids is not None:
+            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+        # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()
+
+        decoder_outputs = self.decoder(
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
+            inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )
 
-        if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
-            inputs["use_cache"] = False
-
-        inputs["output_hidden_states"] = (
-            inputs["output_hidden_states"]
-            if inputs["output_hidden_states"] is not None
-            else self.config.output_hidden_states
-        )
-
-        if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
-            inputs["decoder_input_ids"] = shift_tokens_right(inputs["input_ids"], self.config.pad_token_id)
-
-        if inputs["encoder_outputs"] is None:
-            inputs["encoder_outputs"] = self.encoder(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                head_mask=inputs["head_mask"],
-                inputs_embeds=inputs["inputs_embeds"],
-                output_attentions=inputs["output_attentions"],
-                output_hidden_states=inputs["output_hidden_states"],
-                return_dict=inputs["return_dict"],
-                training=inputs["training"],
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
-        elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
-            inputs["encoder_outputs"] = TFBaseModelOutput(
-                last_hidden_state=inputs["encoder_outputs"][0],
-                hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
-                attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
-            )
-        # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
-        elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple):
-            inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
-
-        decoder_outputs = self.decoder(
-            inputs["decoder_input_ids"],
-            attention_mask=inputs["decoder_attention_mask"],
-            encoder_hidden_states=inputs["encoder_outputs"][0],
-            encoder_attention_mask=inputs["attention_mask"],
-            head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
-
-        if not inputs["return_dict"]:
-            return decoder_outputs + inputs["encoder_outputs"]
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
 
         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1182,9 +1125,9 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
-            encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
-            encoder_attentions=inputs["encoder_outputs"].attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
         )
@@ -1204,6 +1147,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
     def get_decoder(self):
        return self.model.decoder
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1231,9 +1175,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
         training=False,
         **kwargs
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -1250,26 +1193,6 @@ class TFMBartModel(TFMBartPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        outputs = self.model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            encoder_outputs=inputs["encoder_outputs"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
 
         return outputs
@@ -1332,6 +1255,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     def set_bias(self, value):
         self.final_logits_bias = value["final_logits_bias"]
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(MBART_GENERATION_EXAMPLE)
@@ -1365,17 +1289,26 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
         Returns:
 
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
+
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.fill(shape_list(labels), -100),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None:
+                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+        outputs = self.model(
+            input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
-            encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -1383,44 +1316,13 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if inputs["labels"] is not None:
-            inputs["labels"] = tf.where(
-                inputs["labels"] == self.config.pad_token_id,
-                tf.fill(shape_list(inputs["labels"]), -100),
-                inputs["labels"],
-            )
-            inputs["use_cache"] = False
-            if inputs["decoder_input_ids"] is None:
-                inputs["decoder_input_ids"] = shift_tokens_right(inputs["labels"], self.config.pad_token_id)
-
-        outputs = self.model(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            decoder_input_ids=inputs["decoder_input_ids"],
-            encoder_outputs=inputs["encoder_outputs"],
-            decoder_attention_mask=inputs["decoder_attention_mask"],
-            head_mask=inputs["head_mask"],
-            decoder_head_mask=inputs["decoder_head_mask"],
-            cross_attn_head_mask=inputs["cross_attn_head_mask"],
-            past_key_values=inputs["past_key_values"],
-            inputs_embeds=inputs["inputs_embeds"],
-            decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         lm_logits = self.model.shared(outputs[0], mode="linear")
         lm_logits = lm_logits + self.final_logits_bias
-        masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
         return TFSeq2SeqLMOutput(
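
From the caller's side the conversion is intended to be behavior-preserving: the public `call` signatures of the MBart TF classes are unchanged, only the internal argument plumbing differs. A quick sanity check might look like the following sketch (it assumes network access to download the `facebook/mbart-large-cc25` checkpoint, plus TensorFlow and sentencepiece installed):

from transformers import MBartTokenizer, TFMBartModel

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = TFMBartModel.from_pretrained("facebook/mbart-large-cc25")

batch = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="tf")
# decoder_input_ids is omitted on purpose: as in the diff above, the main layer
# falls back to shift_tokens_right(input_ids, pad_token_id).
outputs = model(**batch, output_hidden_states=True)
print(outputs.last_hidden_state.shape)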