Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
TF: use the correct config with (...)EncoderDecoder models (#18097)
This commit is contained in:
parent 4935409757, commit 1fc4b2a132
@@ -403,8 +403,13 @@ def unpack_inputs(func):
         # move any arg into kwargs, if they exist
         fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

-        # process the inputs and call the wrapped function
-        unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
+        # Encoder Decoder models delegate the application of the configuration options to their inner models.
+        if "encoder_decoder" in str(self).lower():
+            config = None
+        else:
+            config = self.config
+
+        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)

     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
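A minimal, self-contained sketch of the dispatch this hunk adds (the Dummy* classes and resolve_config below are hypothetical stand-ins, not the real `unpack_inputs` machinery in modeling_tf_utils.py): composite encoder-decoder models are detected via their repr and get `config=None`, leaving the configuration options to their inner models.

# Sketch only: hypothetical stand-ins for the real TF model classes.
class DummyConfig:
    output_attentions = False

class DummyModel:
    config = DummyConfig()

class DummyEncoderDecoderModel(DummyModel):
    def __repr__(self):
        # In the real library, str(self) contains the module path, e.g.
        # "transformers.models.encoder_decoder.modeling_tf_encoder_decoder...",
        # which is what the "encoder_decoder" substring check matches.
        return "<toy.encoder_decoder.DummyEncoderDecoderModel object>"

def resolve_config(model):
    # Mirrors the new branch in `unpack_inputs`: composite models get None,
    # deferring the boolean options to their inner encoder/decoder configs.
    if "encoder_decoder" in str(model).lower():
        return None
    return model.config

assert resolve_config(DummyModel()) is not None
assert resolve_config(DummyEncoderDecoderModel()) is None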
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
     if "kwargs" in output:
         del output["kwargs"]

-    boolean_dict = {
-        k: v
-        for k, v in output.items()
-        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
-    }
+    if config is not None:
+        boolean_dict = {
+            k: v
+            for k, v in output.items()
+            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+        }

-    output.update(
-        booleans_processing(
-            config=config,
-            **boolean_dict,
-        )
-    )
+        output.update(
+            booleans_processing(
+                config=config,
+                **boolean_dict,
+            )
+        )

     return output
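Why the guard matters, in a hedged, runnable sketch (the toy_* functions are made-up stand-ins for `booleans_processing` / `input_processing`): with a config, unset boolean flags are filled from its defaults; with `config=None`, they pass through untouched so the inner models can resolve them.

# Hypothetical stand-ins for the real helpers in modeling_tf_utils.py.
def toy_booleans_processing(config, **flags):
    # The real helper falls back to config defaults for flags left as None.
    return {k: (v if v is not None else getattr(config, k)) for k, v in flags.items()}

def toy_input_processing(config, **kwargs):
    output = dict(kwargs)
    if config is not None:  # the new guard: composite models pass config=None
        boolean_dict = {
            k: v for k, v in output.items()
            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
        }
        output.update(toy_booleans_processing(config=config, **boolean_dict))
    return output

class Cfg:
    output_attentions = True  # pretend this is the (wrong) top-level default

# With a config, an unset flag is filled from it; with None it stays unset,
# leaving the decision to the inner encoder/decoder models.
assert toy_input_processing(Cfg(), output_attentions=None)["output_attentions"] is True
assert toy_input_processing(None, output_attentions=None)["output_attentions"] is None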
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+        if not return_dict:
+            past_key_values = None
+            if use_cache:
+                past_key_values = decoder_outputs[1]
+            # The starting index of the remaining elements in `decoder_outputs`
+            start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])

-        if not decoder_inputs["return_dict"]:
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
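A worked example (plain Python, hypothetical values) of the `start_index` arithmetic used in the tuple branch above: it counts the leading slots already filled so that slicing `decoder_outputs` does not duplicate them.

# Suppose no labels were passed (loss is None) and use_cache is True.
loss, logits, past_key_values = None, "logits", "past"
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
assert start_index == 2  # logits and past_key_values occupy decoder_outputs[0:2]

# decoder_outputs is ordered (logits, past_key_values, hidden_states, attentions, ...);
# slicing from start_index keeps only the elements not already placed.
decoder_outputs = ("logits", "past", "hidden_states", "attentions")
output = (loss, logits, past_key_values) + decoder_outputs[start_index:]
assert output == (None, "logits", "past", "hidden_states", "attentions")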
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+        if not return_dict:
+            past_key_values = None
+            if use_cache:
+                past_key_values = decoder_outputs[1]
+            # The starting index of the remaining elements in `decoder_outputs`
+            start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])

-        if not decoder_inputs["return_dict"]:
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
@@ -351,32 +351,9 @@ class EncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        labels,
-        **kwargs
+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
     ):
-        # make the decoder inputs a different shape from the encoder inputs to harden the test
-        decoder_input_ids = decoder_input_ids[:, :-1]
-        decoder_attention_mask = decoder_attention_mask[:, :-1]
-        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
-        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        enc_dec_model.to(torch_device)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            output_attentions=True,
-        )
-
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

@@ -408,6 +385,85 @@ class EncoderDecoderMixin:
             (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
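The comment in the new test ("changing the model's config won't work") is the crux of the bug; here is a toy illustration of the config layering (the classes below are hypothetical, not the real `EncoderDecoderConfig`): the outer config is a wrapper, and each sub-model reads only its own config.

# Toy illustration of why enc_dec_model.config.output_attentions has no effect.
class InnerConfig:
    output_attentions = False

class OuterConfig:
    def __init__(self, encoder_cfg, decoder_cfg):
        self.encoder, self.decoder = encoder_cfg, decoder_cfg
        self.output_attentions = False  # setting this is what "won't work"

encoder_cfg, decoder_cfg = InnerConfig(), InnerConfig()
outer = OuterConfig(encoder_cfg, decoder_cfg)

outer.output_attentions = True          # model config -> ignored by sub-models
assert encoder_cfg.output_attentions is False

encoder_cfg.output_attentions = True    # inner model config -> takes effect
decoder_cfg.output_attentions = True
assert outer.encoder.output_attentions is True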
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        **kwargs
+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
     ):
-        # make the decoder inputs a different shape from the encoder inputs to harden the test
-        decoder_input_ids = decoder_input_ids[:, :-1]
-        decoder_attention_mask = decoder_attention_mask[:, :-1]
-        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
-        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            output_attentions=True,
-            kwargs=kwargs,
-        )
-
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
             (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
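To exercise the new checks locally, a filtered pytest run should suffice; the test paths below assume the tests/models layout of the repository around the time of this commit, so treat them as an assumption and adjust to your checkout:

python -m pytest tests/models/encoder_decoder/test_modeling_encoder_decoder.py -k "output_attentions_from_config"
python -m pytest tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py -k "output_attentions_from_config"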