Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
TF: properly handle kwargs in encoder_decoder architectures (#16465)
* properly handle kwargs in encoder_decoder architectures
* make fixup
This commit is contained in:
parent 0540d1b6c0
commit 7a9ef8181c
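Both modeling files change in the same way: user kwargs were previously handed to `input_processing` twice, once merged into the processing-inputs dict and once again through `kwargs_call`. The trailing `kwargs_encoder = {}` / `kwargs_decoder = {}` rebind did not prevent this, since it rebinds the local name while the dict still holds the original reference. The fix passes an empty `kwargs_call` and drops the rebind. A minimal, self-contained sketch of the two paths; `fake_input_processing` is a hypothetical stand-in for transformers' `input_processing`, and the warning on leftover `kwargs_call` entries is an assumption for illustration, not the library's exact behavior:

# Minimal sketch of the bug and the fix. `fake_input_processing` is a
# hypothetical stand-in; its warning behavior is an assumption.
import warnings

def fake_input_processing(kwargs_call=None, **inputs):
    for name in (kwargs_call or {}):
        warnings.warn(f"argument {name} routed through kwargs_call")
    return inputs

kwargs_encoder = {"output_attentions": True}

# Old path: kwargs are merged into the inputs AND passed again via kwargs_call,
# so the same arguments get handled twice.
old_inputs = {"input_ids": [1, 2, 3], "kwargs_call": kwargs_encoder}
for k, v in kwargs_encoder.items():
    old_inputs[k] = v
kwargs_encoder = {}  # rebinds the local name only; old_inputs["kwargs_call"]
                     # still points at the original dict
fake_input_processing(**old_inputs)  # warns: double handling

# New path: kwargs are merged exactly once; kwargs_call stays empty.
kwargs_encoder = {"output_attentions": True}
new_inputs = {"input_ids": [1, 2, 3], "kwargs_call": {}}
new_inputs.update(kwargs_encoder)
fake_input_processing(**new_inputs)  # silent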
@@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "output_hidden_states": output_hidden_states,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_encoder,
+            "kwargs_call": {},
         }
 
         # Add arguments to encoder from `kwargs_encoder`
         for k, v in kwargs_encoder.items():
             encoder_processing_inputs[k] = v
-        kwargs_encoder = {}
 
         encoder_inputs = input_processing(**encoder_processing_inputs)
 
@@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "past_key_values": past_key_values,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_decoder,
+            "kwargs_call": {},
         }
 
         # Add arguments to decoder from `kwargs_decoder`
         for k, v in kwargs_decoder.items():
             decoder_processing_inputs[k] = v
-        kwargs_decoder = {}
 
         decoder_inputs = input_processing(**decoder_processing_inputs)
         decoder_outputs = self.decoder(**decoder_inputs)
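One stylistic difference between the two files survives the commit: `TFEncoderDecoderModel` above merges with an explicit loop, while `TFVisionEncoderDecoderModel` below uses `dict.update`. For these plain keyword dicts the two are equivalent, as a quick check shows:

# Quick equivalence check between the two merge styles used in the two files.
kwargs_encoder = {"output_attentions": True, "use_cache": False}

via_loop = {"training": False}
for k, v in kwargs_encoder.items():
    via_loop[k] = v

via_update = {"training": False}
via_update.update(kwargs_encoder)

assert via_loop == via_update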
@@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "output_hidden_states": output_hidden_states,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_encoder,
+            "kwargs_call": {},
         }
 
         # Add arguments to encoder from `kwargs_encoder`
         encoder_processing_inputs.update(kwargs_encoder)
-        kwargs_encoder = {}
 
         encoder_inputs = input_processing(**encoder_processing_inputs)
 
@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
||||
"past_key_values": past_key_values,
|
||||
"return_dict": return_dict,
|
||||
"training": training,
|
||||
"kwargs_call": kwargs_decoder,
|
||||
"kwargs_call": {},
|
||||
}
|
||||
|
||||
# Add arguments to decoder from `kwargs_decoder`
|
||||
decoder_processing_inputs.update(kwargs_decoder)
|
||||
kwargs_decoder = {}
|
||||
|
||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||
decoder_outputs = self.decoder(**decoder_inputs)
|
||||
|
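Every test hunk below makes the same one-line addition: each helper call now forwards the mixin's `kwargs` dict, so the suite exercises the pass-through path end to end. For context, a hedged sketch of how these models split user kwargs into `kwargs_encoder` and `kwargs_decoder` upstream of the hunks above; the `decoder_` prefix routing matches the encoder-decoder convention, but treat the exact comprehensions as an assumption:

# Hedged sketch of the kwargs split feeding `kwargs_encoder`/`kwargs_decoder`.
# The `decoder_` prefix convention is how these models route extra kwargs; the
# exact comprehensions are an assumption for illustration.
def split_encoder_decoder_kwargs(kwargs):
    kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    return kwargs_encoder, kwargs_decoder

enc, dec = split_encoder_decoder_kwargs({"output_attentions": True, "decoder_use_cache": False})
assert enc == {"output_attentions": True}
assert dec == {"use_cache": False}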
@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
             outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             return_dict=True,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_2 = np.array(outputs[0])
         out_2[np.isnan(out_2)] = 0
@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_1 = np.array(after_outputs[0])
         out_1[np.isnan(out_1)] = 0
@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             labels=labels,
+            kwargs=kwargs,
         )
 
         # Make sure `loss` exist
@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
+            kwargs=kwargs,
         )
 
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
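The vision test mixin below receives the identical treatment, with `pixel_values` in place of `input_ids`/`attention_mask`. A minimal sketch of the pattern being extended; `DummyModel` and the helper are hypothetical stand-ins for the real mixin machinery:

# Hypothetical stand-ins showing why the `kwargs=kwargs` line matters: the
# helper now forwards extra test kwargs into the model call instead of
# silently dropping them.
class DummyModel:
    def __call__(self, decoder_input_ids=None, decoder_attention_mask=None, **kwargs):
        return {"received_kwargs": kwargs}  # record what reached the model

def check_encoder_decoder_model(model, decoder_input_ids, decoder_attention_mask, **kwargs):
    return model(
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        kwargs=kwargs,  # the one-line addition made in every hunk
    )

outputs = check_encoder_decoder_model(DummyModel(), [1, 2], [1, 1], output_attentions=True)
assert outputs["received_kwargs"] == {"kwargs": {"output_attentions": True}}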
@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
             outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
             encoder_outputs=encoder_outputs,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             return_dict=True,
+            kwargs=kwargs,
         )
 
         self.assertEqual(
@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_2 = np.array(outputs[0])
         out_2[np.isnan(out_2)] = 0
@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_1 = np.array(after_outputs[0])
         out_1[np.isnan(out_1)] = 0
@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             labels=labels,
+            kwargs=kwargs,
         )
 
         # Make sure `loss` exist
@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
+            kwargs=kwargs,
         )
 
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]