TF: use the correct config with (...)EncoderDecoder models (#18097)

Joao Gante 2022-07-22 13:31:45 +01:00 committed by GitHub
parent 4935409757
commit 1fc4b2a132
5 changed files with 200 additions and 75 deletions

src/transformers/modeling_tf_utils.py

@@ -403,8 +403,13 @@ def unpack_inputs(func):
         # move any arg into kwargs, if they exist
         fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

         # process the inputs and call the wrapped function
-        unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
+        # Encoder Decoder models delegate the application of the configuration options to their inner models.
+        if "encoder_decoder" in str(self).lower():
+            config = None
+        else:
+            config = self.config
+        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)

     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
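The string check above keys on the default object repr, which embeds the defining module's path. A minimal, runnable sketch of the dispatch (the repr strings below are illustrative, not taken from a live session):

    def uses_inner_configs(model_or_repr) -> bool:
        # The default repr of a model instance includes its module path, e.g.
        # "<transformers.models.encoder_decoder.modeling_tf_encoder_decoder.TFEncoderDecoderModel object at 0x...>",
        # which is what makes composite (...)EncoderDecoder models match the substring check.
        return "encoder_decoder" in str(model_or_repr).lower()

    bert_repr = "<transformers.models.bert.modeling_tf_bert.TFBertModel object at 0x0>"
    composite_repr = "<transformers.models.encoder_decoder.modeling_tf_encoder_decoder.TFEncoderDecoderModel object at 0x0>"

    print(uses_inner_configs(bert_repr))       # False -> pass self.config to input_processing
    print(uses_inner_configs(composite_repr))  # True  -> pass config=None, defer to the inner models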
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
     if "kwargs" in output:
         del output["kwargs"]

-    boolean_dict = {
-        k: v
-        for k, v in output.items()
-        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
-    }
+    if config is not None:
+        boolean_dict = {
+            k: v
+            for k, v in output.items()
+            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+        }

-    output.update(
-        booleans_processing(
-            config=config,
-            **boolean_dict,
+        output.update(
+            booleans_processing(
+                config=config,
+                **boolean_dict,
+            )
         )
-    )

     return output
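With the guard above, boolean resolution only happens when a config is available. A hedged sketch of the resulting behaviour (`resolve_booleans` is a simplified stand-in for `booleans_processing`, not the real implementation):

    def resolve_booleans(config, **booleans):
        # Simplified stand-in for `booleans_processing`: with a config, unset
        # options fall back to the config's defaults; with config=None (the new
        # encoder-decoder path) they pass through untouched, to be resolved
        # later by the inner encoder/decoder against their own configs.
        if config is None:
            return booleans
        return {k: (v if v is not None else getattr(config, k, None)) for k, v in booleans.items()}

    class DummyConfig:
        output_attentions = False
        use_cache = True

    print(resolve_booleans(DummyConfig(), output_attentions=None, use_cache=None))
    # -> {'output_attentions': False, 'use_cache': True}
    print(resolve_booleans(None, output_attentions=None, use_cache=None))
    # -> {'output_attentions': None, 'use_cache': None}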

src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
-
-        if not return_dict:
+        past_key_values = None
+        if use_cache:
+            past_key_values = decoder_outputs[1]
+        # The starting index of the remaining elements in `decoder_outputs`
+        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
+        if not decoder_inputs["return_dict"]:
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
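The `start_index` line above counts how many of the leading output slots are already populated, so the remaining `decoder_outputs` entries can be appended without duplication. A small worked example with invented values:

    # loss/logits/past_key_values mirror the head of the decoder's output tuple;
    # counting the populated ones tells us where its "remaining" elements start.
    loss, logits, past_key_values = None, "logits", None
    start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
    print(start_index)  # 1 -> append decoder_outputs[1:] after (loss, logits, past_key_values)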
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,

src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py

@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
-
-        if not return_dict:
+        past_key_values = None
+        if use_cache:
+            past_key_values = decoder_outputs[1]
+        # The starting index of the remaining elements in `decoder_outputs`
+        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
+        if not decoder_inputs["return_dict"]:
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,

tests/models/encoder_decoder/test_modeling_encoder_decoder.py

@@ -351,32 +351,9 @@ class EncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        labels,
-        **kwargs
+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
     ):
-        # make the decoder inputs a different shape from the encoder inputs to harden the test
-        decoder_input_ids = decoder_input_ids[:, :-1]
-        decoder_attention_mask = decoder_attention_mask[:, :-1]
-        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
-        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        enc_dec_model.to(torch_device)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            output_attentions=True,
-        )
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
@@ -408,6 +385,85 @@ class EncoderDecoderMixin:
             (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrary to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
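In user-facing terms, the behaviour these new tests pin down is roughly the following (a hedged sketch; the checkpoint names are illustrative, and the test itself sets the flags on the configs before instantiating the inner models):

    from transformers import EncoderDecoderModel

    # Illustrative checkpoint names; any encoder/decoder pair behaves the same way.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    model.config.output_attentions = True  # composite config -> not picked up by the forward pass

    model.encoder.config.output_attentions = True  # inner configs -> used as the defaults
    model.decoder.config.output_attentions = True
    # A forward pass would now return encoder/decoder/cross attention tensors.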

tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py

@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        **kwargs
+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
     ):
-        # make the decoder inputs a different shape from the encoder inputs to harden the test
-        decoder_input_ids = decoder_input_ids[:, :-1]
-        decoder_attention_mask = decoder_attention_mask[:, :-1]
-        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
-        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            output_attentions=True,
-            kwargs=kwargs,
-        )
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
             (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrary to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)