TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils

* update models with the new past format

* update template accordingly
This commit is contained in:
Joao Gante 2022-03-08 14:46:44 +00:00 committed by GitHub
parent 62d847602a
commit 70203b5937
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 301 additions and 684 deletions

View File

@ -867,9 +867,8 @@ class TFGenerationMixin:
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# variable to cache compute states
past = None
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
@ -886,6 +885,13 @@ class TFGenerationMixin:
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
# the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
# variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
# `prepare_inputs_for_generation`
if encoder_hidden_states is not None:
encoder_outputs = (*encoder_outputs, encoder_hidden_states)
if encoder_attentions is not None:
encoder_outputs = (*encoder_outputs, encoder_attentions)
# done sentences
done = [False for _ in range(batch_size)]
@ -896,6 +902,7 @@ class TFGenerationMixin:
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
encoder_outputs=encoder_outputs,
**kwargs,
)
outputs = self(
@ -1486,14 +1493,10 @@ class TFGenerationMixin:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
# 4. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
input_ids, return_dict_in_generate, model_kwargs
)
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
@ -1531,10 +1534,6 @@ class TFGenerationMixin:
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 8. run greedy search
return self.greedy_search(
input_ids,
@ -1559,10 +1558,6 @@ class TFGenerationMixin:
**model_kwargs,
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 10. run sample
return self.sample(
input_ids,
@ -1589,12 +1584,7 @@ class TFGenerationMixin:
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
) -> Dict[str, Any]:
# TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
# is cleaned
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
# get encoder and store encoder outputs
encoder = self.get_encoder()
@ -1612,17 +1602,8 @@ class TFGenerationMixin:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
# `encoder_hidden_states` have to be seperated from encoder_outputs and passed
# under other names because of `encoder_outputs`, `past` hack. Need to clean-up
# all encoder-decoder prepare_inputs_for_generation method to clean this
if return_dict_in_generate:
model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
@ -1712,27 +1693,17 @@ class TFGenerationMixin:
return inputs
@staticmethod
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if self._use_cache(outputs, model_kwargs["use_cache"]):
# TODO(Patrick): `past`/`encoder_outputs` hack. This should be
# removed when cleaning up the encoder-decoder models
# if model has past, then set the past variable to speed up decoding
# make this method static then as well
model_kwargs["past"] = outputs[1]
elif "past_key_values" in outputs:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
elif "past" in model_kwargs:
# TODO(Patrick) `past`/`encoder_outputs` hack.
# removed when cleaning up the encoder-decoder models.
# The line should not be necessary.
pass
else:
model_kwargs["past"] = None
@ -1907,26 +1878,18 @@ class TFGenerationMixin:
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. Tis is a bad design and needs
# to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -2129,25 +2092,18 @@ class TFGenerationMixin:
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. This is a bad design and needs to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

View File

@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@ -1012,9 +1012,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1449,43 +1446,23 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1499,15 +1476,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past

View File

@ -1443,17 +1443,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@ -1575,6 +1575,13 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top.""",

View File

@ -18,7 +18,7 @@
import os
import random
import warnings
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@ -1011,9 +1011,6 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1461,43 +1458,23 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1509,15 +1486,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
@staticmethod
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past

View File

@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@ -1010,9 +1010,6 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1434,43 +1431,23 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1482,15 +1459,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
@staticmethod
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past

View File

@ -16,6 +16,7 @@
""" TF 2.0 CTRL model."""
import warnings
from typing import Tuple
import numpy as np
import tensorflow as tf
@ -659,12 +660,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
input_ids = tf.expand_dims(input_ids[:, -1], -1)
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@ -758,6 +759,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
return tuple(
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
)
@add_start_docstrings(
"""

View File

@ -692,52 +692,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
use_cache=None,
**kwargs,
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
if past is None or len(past) not in {1, 2}:
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
if len(past) == 1:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
if len(past) != 2:
raise ValueError(
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
)
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
if not isinstance(encoder_outputs[0], tf.Tensor):
raise ValueError(
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
if not past_key_values:
raise ValueError(
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
)
decoder_input_ids = decoder_input_ids[:, -1:]
if not isinstance(encoder_outputs, TFBaseModelOutput):
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@ -750,9 +719,4 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
if len(past) == 1:
return past
encoder_outputs, past_key_values = past
return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
return self.decoder._reorder_cache(past, beam_idx)

View File

@ -851,12 +851,15 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
# tests will need to be fixed after the change
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(

View File

@ -17,7 +17,7 @@
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import tensorflow as tf
@ -2097,7 +2097,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
all_self_attns = all_self_attns if inputs["output_attentions"] else None
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
present_key_values = present_key_values if inputs["use_cache"] else None
if not inputs["return_dict"]:
return tuple(
@ -2527,45 +2527,26 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs,
TFLEDEncoderBaseModelOutput,
), f"encoder_outputs should be a TFLEDEncoderBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
@ -2574,18 +2555,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""

View File

@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@ -1050,9 +1050,6 @@ class TFMarianDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1477,43 +1474,23 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1528,18 +1505,13 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
@staticmethod
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past
def adjust_logits_during_generation(
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs

View File

@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@ -1034,9 +1034,6 @@ class TFMBartDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1462,43 +1459,23 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1513,15 +1490,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
@staticmethod
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past

View File

@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@ -1058,9 +1058,6 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -1485,43 +1482,23 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1536,15 +1513,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
@staticmethod
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past

View File

@ -16,14 +16,13 @@
"""TFRAG model implementation."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import numpy as np
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list
from ...utils import logging
from .configuration_rag import RagConfig
@ -788,42 +787,28 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, doc_scores, n_docs=None, **kwargs
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor)
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
decoder_cached_states = None
else:
assert len(past) == 2
# Note: encoder_outputs is never changed by Bart as a generator
encoder_outputs, decoder_cached_states = past
if isinstance(encoder_outputs, tuple):
assert isinstance(encoder_outputs[0], tf.Tensor)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
decoder_cached_states
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
# if past is defined cut decoder_input_ids to last token
self,
decoder_input_ids,
past=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
doc_scores=None,
n_docs=None,
**kwargs
):
if past is not None:
# if past is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"input_ids": None,
"encoder_outputs": encoder_outputs,
"doc_scores": doc_scores,
"context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"past_key_values": decoder_cached_states,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"past_key_values": past,
"use_cache": use_cache,
"do_marginalize": True,
"n_docs": n_docs,
}
@ -844,46 +829,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
def _reorder_cache(past, beam_idx):
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
def tf_index_select(input_, dim, indices):
"""
Input:
input_(tensor): input tensor dim(int): dimension indices(list): selected indices list
Output:
mimic of torch_tensor.index_select(dim, indices)
credit:
https://stackoverflow.com/questions/58464790/is-there-an-equivalent-function-of-pytorch-named-index-select-in-tensorflow
"""
shape = shape_list(input_)
if dim == -1:
dim = len(shape) - 1
shape[dim] = 1
tmp = []
for idx in indices:
begin = [0] * len(shape)
begin[dim] = idx
tmp.append(tf.slice(input_, begin, shape))
res = tf.concat(tmp, axis=dim)
return res
def _reorder_stacked(hidden_states, new_order=beam_idx):
def _reorder_stacked(hidden_states, new_order):
n_docs = hidden_states.shape[0] // new_order.shape[0]
hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
hidden_states = tf_index_select(hidden_states, 0, new_order)
return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
if len(past) == 1:
return past
past_key_values = past[1]
hidden_states = tf.gather(hidden_states, new_order, axis=0)
result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
return result
reordered_past = ()
for layer_past in past_key_values:
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
return (past[0], reordered_past)
return reordered_past
def marginalize(self, seq_logits, doc_scores, n_docs=None):
n_docs = n_docs if n_docs is not None else self.config.n_docs
@ -1268,14 +1226,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return_dict=True,
)
if return_dict_in_generate:
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
decoder_input_ids = tf.fill(
(batch_size * num_beams, 1),
tf.cast(decoder_start_token_id, tf.int32),
@ -1366,10 +1316,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
model_kwargs.pop("output_attentions", None)
model_kwargs.pop("output_scores", None)
# TODO(Patrick): `encoder_outputs`, `past` hack.
# Remove after cleaning encoder-decoder outputs
model_kwargs["past"] = encoder_outputs
return self.greedy_search(
input_ids=decoder_input_ids,
max_length=max_length,

View File

@ -1176,17 +1176,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
return self.mlm.predictions
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@ -1309,6 +1309,14 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
@staticmethod
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
@add_start_docstrings(
"""

View File

@ -1209,17 +1209,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
return self.name + "/" + self.lm_head.name
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@ -1344,6 +1344,14 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
@staticmethod
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
class TFRobertaClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""

View File

@ -1139,7 +1139,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
next_cache = (inputs["encoder_hidden_states"], next_decoder_cache) if use_cache else None
next_cache = next_decoder_cache if use_cache else None
if not inputs["return_dict"]:
return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns
@ -1571,26 +1571,17 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
if past is not None and len(past) <= 2:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
if len(past) == 1:
past_key_values = None
else:
past_key_values = past[1]
if not past_key_values:
raise ValueError(f"decoder cached states must be truthy, got {past_key_values}")
decoder_input_ids = decoder_input_ids[:, -1:]
else:
raise ValueError(f"`past` must be an iterable with length 1 or 2, got {past}")
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"input_features": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1601,15 +1592,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past

View File

@ -1256,15 +1256,13 @@ class TFT5Model(TFT5PreTrainedModel):
return_dict=inputs["return_dict"],
training=inputs["training"],
)
past = decoder_outputs[1] if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + inputs["encoder_outputs"]
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=past,
@ -1483,8 +1481,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
past = decoder_outputs[1] if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
@ -1509,8 +1507,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attentions=attentions,
)
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
@ -1544,65 +1540,57 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def prepare_inputs_for_generation(
self,
inputs,
past,
attention_mask,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
use_cache=None,
**kwargs,
encoder_outputs=None,
**kwargs
):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if len(past) < 2:
encoder_outputs, past_key_values = past, None
else:
encoder_outputs, past_key_values = past[0], past[1]
if "encoder_hidden_states" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
if "encoder_attentions" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])
# cut decoder_input_ids if past is used
if past_key_values is not None:
inputs = inputs[:, -1:]
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
"past_key_values": past_key_values,
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": input_ids,
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past, beam_idx) -> Tuple:
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if len(past) < 2:
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()
for layer_past_states in decoder_past:
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)
reordered_layer_past_states = reordered_layer_past_states + (
tf.gather(layer_past_state, beam_idx, axis=0),
)
assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)
return reordered_decoder_past
@add_start_docstrings(

View File

@ -1058,15 +1058,22 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
attentions=attns,
)
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs = {"input_ids": inputs}
def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
inputs = {}
# if past is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
inputs["input_ids"] = tf.expand_dims(input_ids[:, -1], axis=-1)
else:
inputs["input_ids"] = input_ids
return inputs
@staticmethod
def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
@add_start_docstrings(
"""

View File

@ -722,45 +722,22 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
cross_attentions=cross_attns,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, use_cache=None, **kwargs):
if past is None or len(past) not in {1, 2}:
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
if len(past) == 1:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
if len(past) != 2:
raise ValueError(
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
)
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
if not isinstance(encoder_outputs[0], tf.Tensor):
raise ValueError(
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
if not past_key_values:
raise ValueError(
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
)
decoder_input_ids = decoder_input_ids[:, -1:]
if not isinstance(encoder_outputs, TFBaseModelOutput):
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
return {
"pixel_values": None, # encoder_outputs is defined. pixel_values not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@ -773,9 +750,4 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
if len(past) == 1:
return past
encoder_outputs, past_key_values = past
return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
return self.decoder._reorder_cache(past, beam_idx)

View File

@ -1246,17 +1246,17 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_loss.name
def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
# At every pass, the attention values for the new token and the two last generated tokens
# are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
# offset = 1; offset = 2 seems to have slightly better computation.
offset = 2
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
if past:
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
else:
@ -1277,7 +1277,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
"input_ids": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_mems": kwargs.get("use_mems"),
"use_mems": use_mems,
}
# if past is defined in model kwargs then use it for faster decoding

View File

@ -1777,7 +1777,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
{% else %}
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@ -2736,9 +2736,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -3186,43 +3183,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=False,
use_cache=None,
encoder_outputs=None,
**kwargs
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -3233,17 +3210,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
)
return (past[0], reordered_past)
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""

View File

@ -802,7 +802,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -116,7 +116,6 @@ class TFBartModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -114,7 +114,6 @@ class TFBlenderbotModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -114,7 +114,6 @@ class TFBlenderbotSmallModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -133,7 +133,6 @@ class TFLEDModelTester:
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -116,7 +116,6 @@ class TFMarianModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -114,7 +114,6 @@ class TFPegasusModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

View File

@ -182,7 +182,7 @@ class TFSpeech2TextModelTester:
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
_, (_, past_key_values) = outputs.to_tuple()
_, past_key_values = outputs.to_tuple()
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2)

View File

@ -98,13 +98,10 @@ class TFT5ModelTester:
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
self.parent.assertEqual(len(decoder_past[0]), 4)
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5ForConditionalGeneration(config=config)