TF: properly handle kwargs in encoder_decoder architectures (#16465)

* properly handle kwargs in encoder_decoder architectures

* make fixup
This commit is contained in:
Joao Gante 2022-03-29 18:17:47 +01:00 committed by GitHub
parent 0540d1b6c0
commit 7a9ef8181c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 8 deletions

View File

@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}
# Add arguments to encoder from `kwargs_encoder`
for k, v in kwargs_encoder.items():
encoder_processing_inputs[k] = v
kwargs_encoder = {}
encoder_inputs = input_processing(**encoder_processing_inputs)
@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}
# Add arguments to decoder from `kwargs_decoder`
for k, v in kwargs_decoder.items():
decoder_processing_inputs[k] = v
kwargs_decoder = {}
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)

View File

@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}
# Add arguments to encoder from `kwargs_encoder`
encoder_processing_inputs.update(kwargs_encoder)
kwargs_encoder = {}
encoder_inputs = input_processing(**encoder_processing_inputs)
@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}
# Add arguments to decoder from `kwargs_decoder`
decoder_processing_inputs.update(kwargs_decoder)
kwargs_decoder = {}
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)

View File

@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]

View File

@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]