Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Use cache] Align logic of use_cache with output_attentions and output_hidden_states (#5194)

* fix use cache
* add bart use cache
* fix bart
* finish bart
This commit is contained in:
parent 64c393ee74
commit c2a26ec8a6
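The hunks below apply one idiom everywhere: flag arguments default to None, and None means "fall back to the value in the model config". A minimal sketch of the pattern (the config object here is a stand-in, not the real PretrainedConfig):

    from types import SimpleNamespace

    config = SimpleNamespace(use_cache=True, output_attentions=False, output_hidden_states=False)

    def forward(use_cache=None, output_attentions=None, output_hidden_states=None):
        # None means "not set by the caller", so the config value wins;
        # an explicit True/False from the caller always takes precedence.
        use_cache = use_cache if use_cache is not None else config.use_cache
        output_attentions = output_attentions if output_attentions is not None else config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else config.output_hidden_states
        return use_cache, output_attentions, output_hidden_states

    assert forward() == (True, False, False)      # config defaults apply
    assert forward(use_cache=False)[0] is False   # caller overrides config

With the old hard-coded defaults (use_cache=False for Bart and T5, use_cache=True for GPT-2 and CTRL), config.use_cache could never take effect, and an explicit use_cache=False was indistinguishable from "argument not passed".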
@@ -815,14 +815,19 @@ class BartModel(PretrainedBartModel):
         encoder_outputs: Optional[Tuple] = None,
         decoder_attention_mask=None,
         decoder_cached_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):

+        if decoder_input_ids is None:
+            use_cache = False
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache

         # make masks if user doesn't supply
         if not use_cache:
@@ -915,7 +920,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         labels=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **unused,
@@ -968,6 +973,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
             )
             labels = unused.pop("lm_labels")

+        if labels is not None:
+            use_cache = False
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
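The conditional-generation head (and, below, the classification and QA heads) now forces the cache off as soon as labels are supplied. That is the training path: with teacher forcing, every decoder position is computed in a single forward pass, so there are no incremental steps that could reuse past key/value states, and returning them would only change the output tuple's shape and waste memory. A toy sketch of the precedence this creates (decision logic only, not the real forward):

    def resolve_use_cache(use_cache, labels, config_use_cache):
        # training signal present -> caching is pointless, force it off
        if labels is not None:
            return False
        # inference: explicit argument wins, config is the fallback
        return use_cache if use_cache is not None else config_use_cache

    assert resolve_use_cache(True, labels="ids", config_use_cache=True) is False
    assert resolve_use_cache(None, labels=None, config_use_cache=True) is True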
@@ -1070,6 +1078,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1106,6 +1115,9 @@ class BartForSequenceClassification(PretrainedBartModel):
            loss, logits = outputs[:2]

        """
+        if labels is not None:
+            use_cache = False
+
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
@@ -1114,6 +1126,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )
         x = outputs[0]  # last hidden state
         eos_mask = input_ids.eq(self.config.eos_token_id)
@@ -1159,6 +1172,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
         end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1206,6 +1220,8 @@ class BartForQuestionAnswering(PretrainedBartModel):
            answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

        """
+        if start_positions is not None and end_positions is not None:
+            use_cache = False

        outputs = self.model(
            input_ids,
@@ -1215,6 +1231,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )

         sequence_output = outputs[0]
@@ -335,7 +335,7 @@ class CTRLModel(CTRLPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -374,6 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
@@ -519,7 +520,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -379,7 +379,7 @@ class GPT2Model(GPT2PreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -420,6 +420,7 @@ class GPT2Model(GPT2PreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -562,7 +563,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -671,7 +672,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_token_ids=None,
         labels=None,
         mc_labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **kwargs
@@ -659,11 +659,12 @@ class T5Stack(T5PreTrainedModel):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):

+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -854,6 +855,7 @@ class T5Model(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)

         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)

         decoder_config = copy.deepcopy(config)
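T5Model (and, below, T5ForConditionalGeneration and the TF variants) now copies the config and pins use_cache = False on the encoder copy. Caching past key/value states only pays off for autoregressive decoding, where each new token attends to previously computed states; the encoder processes the whole input in a single pass and would only waste memory by returning a cache. A minimal sketch of the pattern (toy config class, not the real T5Config):

    import copy

    class ToyConfig:
        def __init__(self, use_cache=True):
            self.use_cache = use_cache

    config = ToyConfig(use_cache=True)

    # the decoder keeps the user's setting; the encoder copy never caches
    encoder_config = copy.deepcopy(config)
    encoder_config.use_cache = False
    decoder_config = copy.deepcopy(config)

    assert config.use_cache and decoder_config.use_cache
    assert not encoder_config.use_cache  # deepcopy keeps the original intact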
@@ -893,7 +895,7 @@ class T5Model(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         head_mask=None,
@@ -933,6 +935,7 @@ class T5Model(T5PreTrainedModel):
            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
+        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
@@ -985,6 +988,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)

         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)

         decoder_config = copy.deepcopy(config)
@@ -1021,7 +1025,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         labels=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
@@ -1086,6 +1090,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
            labels = kwargs.pop("lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
@@ -186,6 +186,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache

         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
@@ -235,7 +236,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -270,6 +271,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):

         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache

         # If using past key value states, only the last tokens
         # should be given as an input
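One asymmetry worth noting against the PyTorch side: the Keras main layers capture the config flags as attributes at construction time (self.use_cache = config.use_cache in __init__), and the call-time fallback reads self.use_cache, not self.config.use_cache. A plain-Python sketch of that structure (no Keras needed to see the consequence):

    class ToyMainLayer:
        def __init__(self, config):
            # flag snapshot taken once, when the layer is built
            self.use_cache = config.use_cache

        def call(self, use_cache=None):
            # fallback consults the snapshot, not the live config
            return use_cache if use_cache is not None else self.use_cache

    class ToyConfig:
        use_cache = True

    config = ToyConfig()
    layer = ToyMainLayer(config)
    config.use_cache = False      # flipped after construction...
    assert layer.call() is True   # ...the snapshot still wins

So for the TF models, changing config.use_cache after the layer is built does not affect the fallback, whereas the PyTorch models read self.config.use_cache at call time.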
@@ -215,6 +215,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         super().__init__(*inputs, **kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.use_cache = config.use_cache
+
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd
@@ -254,10 +256,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
-        training=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        training=False,
     ):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -288,6 +290,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):

         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -622,7 +625,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         mc_token_ids=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -518,6 +518,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache

         self.embed_tokens = embed_tokens
         self.is_decoder = config.is_decoder
@@ -556,7 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -586,6 +587,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):

         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache

         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
@@ -874,6 +876,7 @@ class TFT5Model(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)

         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

         decoder_config = copy.deepcopy(config)
@@ -952,11 +955,13 @@ class TFT5Model(TFT5PreTrainedModel):
             decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
             decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
             decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-            use_cache = kwargs.get("use_cache", True)
+            use_cache = kwargs.get("use_cache", None)
             head_mask = kwargs.get("head_mask", None)
             output_attentions = kwargs.get("output_attentions", None)
             output_hidden_states = kwargs.get("output_hidden_states", None)

+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
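The TF T5 models accept their inputs positionally, as a list/tuple, or as a dict of keyword arguments, so the default has to change in the kwargs.get(...) lookups as well; a leftover kwargs.get("use_cache", True) would have silently bypassed the config. A toy sketch of that two-stage resolution:

    def resolve_from_kwargs(kwargs, config_use_cache):
        # stage 1: dict-style inputs; None now means "not provided"
        use_cache = kwargs.get("use_cache", None)
        # stage 2: the same config fallback as everywhere else
        return use_cache if use_cache is not None else config_use_cache

    assert resolve_from_kwargs({}, config_use_cache=False) is False
    assert resolve_from_kwargs({"use_cache": True}, config_use_cache=False) is True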
@@ -1014,6 +1019,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)

         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

         decoder_config = copy.deepcopy(config)
@@ -1095,13 +1101,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
             encoder_outputs = kwargs.get("encoder_outputs", None)
             decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
             decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-            use_cache = kwargs.get("use_cache", True)
+            use_cache = kwargs.get("use_cache", None)
             inputs_embeds = kwargs.get("inputs_embeds", None)
             decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
             head_mask = kwargs.get("head_mask", None)
             output_attentions = kwargs.get("output_attentions", None)
             output_hidden_states = kwargs.get("output_hidden_states", None)

+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
@@ -153,6 +153,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):

    def test_advanced_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.use_cache = False
        inputs_dict["input_ids"][:, -2:] = config.pad_token_id
        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
            config, inputs_dict["input_ids"]
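The one-line test change guards against the new fallback: previously, omitting use_cache in a Bart forward call meant "off" (the old hard default); now it means "whatever config.use_cache says". Pinning config.use_cache = False keeps test_advanced_inputs exercising the no-cache path regardless of the config default. A minimal illustration (SimpleNamespace stands in for the real config):

    from types import SimpleNamespace

    config = SimpleNamespace(use_cache=True)

    def forward(use_cache=None):
        return use_cache if use_cache is not None else config.use_cache

    assert forward() is True     # omitting the argument no longer means "off"
    config.use_cache = False     # the pin the test now needs
    assert forward() is False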
@@ -168,7 +168,14 @@ class GPT2ModelTester:
         model.eval()

         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
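The rewritten tester exercises the behavior the cache exists for: after one full pass with use_cache=True, only the newly generated token has to be fed back together with past. A hedged usage sketch against the 2020-era tuple-returning API (the past argument name and the tuple layout match that era, not today's):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    input_ids = tokenizer.encode("The cache makes decoding", return_tensors="pt")

    with torch.no_grad():
        # full pass once; `past` holds every layer's key/value states
        logits, past = model(input_ids, use_cache=True)[:2]
        next_token = logits[:, -1:].argmax(dim=-1)

        # incremental step: only the new token goes in, `past` does the rest
        logits, past = model(next_token, past=past, use_cache=True)[:2]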
@@ -193,7 +193,14 @@ class T5ModelTester:
         model.eval()

         # first forward pass
-        output, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -126,6 +126,7 @@ class TFModelTesterMixin:
            if "T5" in main_layer_class.__name__:
                # Take the same values than in TFT5ModelTester for this shared layer
                shared = TFSharedEmbeddings(99, 32, name="shared")
+                config.use_cache = False
                main_layer = main_layer_class(config, embed_tokens=shared)
            else:
                main_layer = main_layer_class(config)
@@ -143,7 +143,14 @@ class TFGPT2ModelTester:
         model = TFGPT2Model(config=config)

         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -135,7 +135,15 @@ class TFT5ModelTester:
         self.batch_size = 1

         # first forward pass
-        _, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs

         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
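The length asserts above only pin down the shape of the returned tuples. The remainder of each tester (outside this hunk) is what gives the cache its correctness guarantee: it appends the hypothetical next token, runs one cached step and one from-scratch pass, and checks that the logits agree. The core of that comparison, as a self-contained sketch (the tolerance value is illustrative):

    import torch

    def assert_cache_equivalent(logits_full, logits_cached, atol=1e-3):
        # the from-scratch pass computed every position; the cached step only
        # produced the newest one, so compare the last position of each
        assert torch.allclose(logits_full[:, -1], logits_cached[:, -1], atol=atol)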