[Use cache] Align logic of use_cache with output_attentions and output_hidden_states (#5194)

* fix use cache

* add bart use cache

* fix bart

* finish bart
Patrick von Platen, 2020-06-24 16:09:17 +02:00 (committed by GitHub)
parent 64c393ee74
commit c2a26ec8a6
13 changed files with 90 additions and 21 deletions
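For orientation, here is a minimal, illustrative sketch of the pattern this commit applies in every forward/call method below (the class and helper names are made up; only the three fallback lines mirror the diff): an explicit use_cache argument wins, otherwise the value is read from the model config, exactly as already done for output_attentions and output_hidden_states.

from types import SimpleNamespace


class ConfigFallbackSketch:
    """Toy stand-in for a transformers model; not real library code."""

    def __init__(self, config):
        self.config = config

    def forward(self, input_ids, use_cache=None, output_attentions=None, output_hidden_states=None):
        # Resolve flags: explicit argument > config default (the pattern added by this commit).
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return use_cache, output_attentions, output_hidden_states


config = SimpleNamespace(use_cache=True, output_attentions=False, output_hidden_states=False)
model = ConfigFallbackSketch(config)
assert model.forward(input_ids=[1, 2, 3]) == (True, False, False)       # falls back to config
assert model.forward(input_ids=[1, 2, 3], use_cache=False)[0] is False  # explicit value wins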

View File

@@ -815,14 +815,19 @@ class BartModel(PretrainedBartModel):
         encoder_outputs: Optional[Tuple] = None,
         decoder_attention_mask=None,
         decoder_cached_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
+
+        if decoder_input_ids is None:
+            use_cache = False
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         # make masks if user doesn't supply
         if not use_cache:
@@ -915,7 +920,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         labels=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **unused,
@@ -968,6 +973,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
             )
             labels = unused.pop("lm_labels")
 
+        if labels is not None:
+            use_cache = False
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1070,6 +1078,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1106,6 +1115,9 @@ class BartForSequenceClassification(PretrainedBartModel):
             loss, logits = outputs[:2]
         """
+        if labels is not None:
+            use_cache = False
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1114,6 +1126,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )
         x = outputs[0]  # last hidden state
         eos_mask = input_ids.eq(self.config.eos_token_id)
@@ -1159,6 +1172,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
         end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1206,6 +1220,8 @@ class BartForQuestionAnswering(PretrainedBartModel):
             answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
         """
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
 
         outputs = self.model(
             input_ids,
@@ -1215,6 +1231,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )
         sequence_output = outputs[0]
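A short usage sketch of the resulting BART behaviour (my own example, not code from the commit; the tiny hyper-parameters and token ids are arbitrary): with decoder_input_ids supplied, an unset use_cache falls back to config.use_cache, while omitting decoder_input_ids forces caching off, matching the guards added above.

import torch
from transformers import BartConfig, BartModel

# Deliberately small config so the sketch is cheap to run.
config = BartConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = BartModel(config).eval()

input_ids = torch.tensor([[0, 4, 5, 6, 2]])
decoder_input_ids = torch.tensor([[0, 4, 5]])

with torch.no_grad():
    # use_cache is not passed, so it resolves to config.use_cache.
    cached = model(input_ids, decoder_input_ids=decoder_input_ids)
    # Without decoder_input_ids the new guard forces use_cache = False.
    uncached = model(input_ids)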

View File

@@ -335,7 +335,7 @@ class CTRLModel(CTRLPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -374,6 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -519,7 +520,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):

View File

@@ -379,7 +379,7 @@ class GPT2Model(GPT2PreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -420,6 +420,7 @@ class GPT2Model(GPT2PreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -562,7 +563,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -671,7 +672,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_token_ids=None,
         labels=None,
         mc_labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **kwargs

View File

@@ -659,11 +659,12 @@ class T5Stack(T5PreTrainedModel):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -854,6 +855,7 @@ class T5Model(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)
 
         decoder_config = copy.deepcopy(config)
@@ -893,7 +895,7 @@ class T5Model(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         head_mask=None,
@@ -933,6 +935,7 @@ class T5Model(T5PreTrainedModel):
             last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
         """
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
@@ -985,6 +988,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)
 
         decoder_config = copy.deepcopy(config)
@@ -1021,7 +1025,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         labels=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
@@ -1086,6 +1090,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             labels = kwargs.pop("lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
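As a hedged usage sketch of the T5 side (again my own example, not part of the commit; the small hyper-parameters are arbitrary): leaving use_cache unset now defers to config.use_cache, an explicit use_cache=False drops the decoder past key/value states from the returned tuple, and the encoder stack never caches because its copied config sets use_cache = False.

import torch
from transformers import T5Config, T5Model

# Small, arbitrary hyper-parameters so the example is cheap to run.
config = T5Config(vocab_size=99, d_model=16, d_ff=32, d_kv=8, num_layers=2, num_heads=2)
model = T5Model(config).eval()

input_ids = torch.tensor([[3, 4, 5, 6, 1]])
decoder_input_ids = torch.tensor([[0, 3, 4]])

with torch.no_grad():
    # use_cache unset -> falls back to config.use_cache, so the decoder
    # past key/value states are part of the output tuple.
    with_cache = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    # Explicit override always wins over the config value.
    without_cache = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, use_cache=False)

# Mirrors the check added to the testers below: caching contributes exactly one
# extra element to the output tuple.
assert len(with_cache) == len(without_cache) + 1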

View File

@@ -186,6 +186,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache
 
         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
@@ -235,7 +236,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -270,6 +271,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         # If using past key value states, only the last tokens
         # should be given as an input

View File

@@ -215,6 +215,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         super().__init__(*inputs, **kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.use_cache = config.use_cache
+
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd
@@ -254,10 +256,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
-        training=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        training=False,
     ):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -288,6 +290,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -622,7 +625,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         mc_token_ids=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,

View File

@@ -518,6 +518,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache
 
         self.embed_tokens = embed_tokens
         self.is_decoder = config.is_decoder
@@ -556,7 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -586,6 +587,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
@@ -874,6 +876,7 @@ class TFT5Model(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
 
         decoder_config = copy.deepcopy(config)
@@ -952,11 +955,13 @@ class TFT5Model(TFT5PreTrainedModel):
             decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
             decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
             decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-            use_cache = kwargs.get("use_cache", True)
+            use_cache = kwargs.get("use_cache", None)
             head_mask = kwargs.get("head_mask", None)
             output_attentions = kwargs.get("output_attentions", None)
             output_hidden_states = kwargs.get("output_hidden_states", None)
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -1014,6 +1019,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
 
         decoder_config = copy.deepcopy(config)
@@ -1095,13 +1101,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
             encoder_outputs = kwargs.get("encoder_outputs", None)
             decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
             decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-            use_cache = kwargs.get("use_cache", True)
+            use_cache = kwargs.get("use_cache", None)
             inputs_embeds = kwargs.get("inputs_embeds", None)
             decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
             head_mask = kwargs.get("head_mask", None)
             output_attentions = kwargs.get("output_attentions", None)
             output_hidden_states = kwargs.get("output_hidden_states", None)
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed

View File

@@ -153,6 +153,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_advanced_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.use_cache = False
         inputs_dict["input_ids"][:, -2:] = config.pad_token_id
         decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
             config, inputs_dict["input_ids"]

View File

@@ -168,7 +168,14 @@ class GPT2ModelTester:
         model.eval()
 
         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

View File

@@ -193,7 +193,14 @@ class T5ModelTester:
         model.eval()
 
         # first forward pass
-        output, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

View File

@@ -126,6 +126,7 @@ class TFModelTesterMixin:
             if "T5" in main_layer_class.__name__:
                 # Take the same values than in TFT5ModelTester for this shared layer
                 shared = TFSharedEmbeddings(99, 32, name="shared")
+                config.use_cache = False
                 main_layer = main_layer_class(config, embed_tokens=shared)
             else:
                 main_layer = main_layer_class(config)

View File

@@ -143,7 +143,14 @@ class TFGPT2ModelTester:
         model = TFGPT2Model(config=config)
 
         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

View File

@@ -135,7 +135,15 @@ class TFT5ModelTester:
         self.batch_size = 1
 
         # first forward pass
-        _, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs
+
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)