TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils
* update models with the new past format
* update template accordingly

parent 62d847602a
commit 70203b5937
@@ -867,9 +867,8 @@ class TFGenerationMixin:
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# variable to cache compute states
past = None
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
@@ -886,6 +885,13 @@ class TFGenerationMixin:
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
# the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
# variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
# `prepare_inputs_for_generation`
if encoder_hidden_states is not None:
encoder_outputs = (*encoder_outputs, encoder_hidden_states)
if encoder_attentions is not None:
encoder_outputs = (*encoder_outputs, encoder_attentions)
# done sentences
done = [False for _ in range(batch_size)]
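To make the change in this first hunk concrete, here is a minimal, self-contained sketch of the two `past` conventions (plain TensorFlow, no transformers imports; the variable names are illustrative, not taken from the diff):

import tensorflow as tf

# Dummy encoder output: (batch, sequence, hidden)
encoder_last_hidden_state = tf.zeros((2, 5, 8))
encoder_outputs = (encoder_last_hidden_state,)

# Old convention (removed here): encoder outputs were parked in `past` before any
# decoder cache existed, so every model had to unpack them again downstream.
old_past = encoder_outputs

# New convention: `past` only ever holds the decoder cache and starts out empty;
# `encoder_outputs` is threaded through the generate kwargs separately.
new_past = None

print(type(old_past), new_past, encoder_outputs[0].shape)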
@@ -896,6 +902,7 @@ class TFGenerationMixin:
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
encoder_outputs=encoder_outputs,
**kwargs,
)
outputs = self(
@@ -1486,14 +1493,10 @@ class TFGenerationMixin:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
# 4. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
input_ids, return_dict_in_generate, model_kwargs
)
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
@@ -1531,10 +1534,6 @@ class TFGenerationMixin:
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 8. run greedy search
return self.greedy_search(
input_ids,
@@ -1559,10 +1558,6 @@ class TFGenerationMixin:
**model_kwargs,
)
# TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
# generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
# 10. run sample
return self.sample(
input_ids,
@@ -1589,12 +1584,7 @@ class TFGenerationMixin:
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
) -> Dict[str, Any]:
# TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
# is cleaned
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
# get encoder and store encoder outputs
encoder = self.get_encoder()
@@ -1612,17 +1602,8 @@ class TFGenerationMixin:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
# `encoder_hidden_states` have to be seperated from encoder_outputs and passed
# under other names because of `encoder_outputs`, `past` hack. Need to clean-up
# all encoder-decoder prepare_inputs_for_generation method to clean this
if return_dict_in_generate:
model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)
return model_kwargs
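A rough sketch of what the slimmed-down helper now does, assuming only an object exposing `get_encoder()`; the argument filtering in the real method is more involved, so treat this as an outline rather than the actual implementation:

import tensorflow as tf
from typing import Any, Dict

def prepare_encoder_decoder_kwargs(model, input_ids: tf.Tensor, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # run the encoder once and stash its full output object under "encoder_outputs";
    # the decoding loop later reads it back from model_kwargs instead of from `past`
    encoder = model.get_encoder()
    encoder_kwargs = {k: v for k, v in model_kwargs.items() if not k.startswith("decoder_")}
    model_kwargs["encoder_outputs"] = encoder(input_ids, **encoder_kwargs)
    return model_kwargs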
def _prepare_decoder_input_ids_for_generation(
@@ -1712,27 +1693,17 @@ class TFGenerationMixin:
return inputs
@staticmethod
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if self._use_cache(outputs, model_kwargs["use_cache"]):
# TODO(Patrick): `past`/`encoder_outputs` hack. This should be
# removed when cleaning up the encoder-decoder models
# if model has past, then set the past variable to speed up decoding
# make this method static then as well
model_kwargs["past"] = outputs[1]
elif "past_key_values" in outputs:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
elif "past" in model_kwargs:
# TODO(Patrick) `past`/`encoder_outputs` hack.
# removed when cleaning up the encoder-decoder models.
# The line should not be necessary.
pass
else:
model_kwargs["past"] = None
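A condensed sketch of the cache dispatch that `_update_model_kwargs_for_generation` performs after this hunk; the `outputs` argument is modeled here as a plain dict standing in for a `ModelOutput`, and the legacy `elif "past" in model_kwargs` branch is omitted for brevity:

from typing import Any, Dict

def update_past(outputs: Dict[str, Any], model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # pick whichever cache the model returned; fall back to None, so `past`
    # never has to carry encoder outputs as a side channel anymore
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    elif "mems" in outputs:                      # Transformer-XL style cache
        model_kwargs["past"] = outputs["mems"]
    elif "past_buckets_states" in outputs:       # Reformer style cache
        model_kwargs["past"] = outputs["past_buckets_states"]
    else:
        model_kwargs["past"] = None
    return model_kwargs

print(update_past({"past_key_values": "dummy-cache"}, {}))  # {'past': 'dummy-cache'}
print(update_past({"logits": "dummy-logits"}, {}))          # {'past': None}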
@@ -1907,26 +1878,18 @@ class TFGenerationMixin:
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. Tis is a bad design and needs
# to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2129,25 +2092,18 @@ class TFGenerationMixin:
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
# to be wrapped into `past` variable. This is a bad design and needs to be updated.
# Remove the following lines when updating all encoder-decoder models
encoder_outputs = model_kwargs.pop("encoder_outputs", None)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
while cur_len < max_length:
# TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
# in all models
model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
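The access pattern introduced above, reading encoder attentions and hidden states straight out of `model_kwargs["encoder_outputs"]` instead of a popped local, can be exercised with a dict stand-in (dummy values, no model needed):

# Stand-in for the dict-like ModelOutput stored under model_kwargs["encoder_outputs"];
# the values are dummies, the access pattern mirrors the new greedy_search/sample code.
model_kwargs = {
    "encoder_outputs": {
        "last_hidden_state": "dummy-hidden",
        "attentions": ("layer0-attn", "layer1-attn"),
        "hidden_states": ("embeddings", "layer0", "layer1"),
    }
}
output_attentions, output_hidden_states = True, False

encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
print(encoder_attentions, encoder_hidden_states)  # ('layer0-attn', 'layer1-attn') None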
@@ -16,7 +16,7 @@
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import tensorflow as tf
@@ -1012,9 +1012,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@@ -1449,43 +1446,23 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
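For orientation, a sketch of the dictionary a Bart-style `prepare_inputs_for_generation` now returns once `past` no longer wraps the encoder outputs; the function below is a simplified stand-in, not the method itself, and the cache value is a dummy:

import tensorflow as tf

def prepare_inputs_sketch(decoder_input_ids, past=None, attention_mask=None, encoder_outputs=None, use_cache=None):
    # once a decoder cache exists, only the last generated token is fed back in
    if past is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": None,                  # encoder_outputs is defined, input_ids not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,            # plain decoder cache, no encoder outputs packed in
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
    }

dummy_ids = tf.constant([[2, 15, 27]])
print(prepare_inputs_sketch(dummy_ids, past=("dummy-cache",))["decoder_input_ids"].shape)  # (1, 1)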
@@ -1499,15 +1476,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past
past_key_values = past[1]
reordered_past = ()
for layer_past_key_values in past_key_values:
for layer_past in past:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
)
return (past[0], reordered_past)
return reordered_past
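A self-contained illustration of what the simplified `_reorder_cache` computes; the cache shapes are made up, and the leading axis is the flattened batch x beam dimension that `beam_idx` indexes:

import tensorflow as tf

def reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        # self-attention key/value states (first two entries) follow the selected beams;
        # cached cross-attention states are identical across beams and are kept as-is
        reordered_past += (
            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past

layer = (tf.random.normal((4, 2, 3, 5)),) * 4   # (batch*beams, heads, seq, head_dim) per state
past = (layer, layer)                           # two decoder layers
beam_idx = tf.constant([1, 1, 0, 3])
print(reorder_cache(past, beam_idx)[0][0].shape)  # (4, 2, 3, 5)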
@@ -1443,17 +1443,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
# cut decoder_input_ids if past is used
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": model_kwargs["use_cache"],
}
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1575,6 +1575,13 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top.""",
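The decoder-only variant of the same idea (Bert LM head, and similarly GPT-2/CTRL): with a cache present, only the newest token id is forwarded. A toy sketch with dummy ids, outside any model class:

import tensorflow as tf

def prepare_decoder_only_inputs(input_ids, past=None, attention_mask=None):
    if attention_mask is None:
        attention_mask = tf.ones(input_ids.shape)
    # with a cache, earlier positions are already encoded, so keep only the last token
    if past is not None:
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

ids = tf.constant([[101, 7592, 2088]])
print(prepare_decoder_only_inputs(ids)["input_ids"].shape)                   # (1, 3) first step
print(prepare_decoder_only_inputs(ids, past=("cache",))["input_ids"].shape)  # (1, 1) later steps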
@ -18,7 +18,7 @@
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -1011,9 +1011,6 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@ -1461,43 +1458,23 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@ -1509,15 +1486,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -1010,9 +1010,6 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@ -1434,43 +1431,23 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@ -1482,15 +1459,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
@@ -16,6 +16,7 @@
""" TF 2.0 CTRL model."""
import warnings
from typing import Tuple
import numpy as np
import tensorflow as tf
@@ -659,12 +660,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
input_ids = tf.expand_dims(input_ids[:, -1], -1)
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -758,6 +759,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
return tuple(
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
)
@add_start_docstrings(
"""
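Note the key rename in the returned dict: the cache now travels under "past_key_values" rather than "past". A tiny sketch of the before/after shape of the input dict (dummy cache object):

cache = ("layer0-cache", "layer1-cache")

old_inputs = {"input_ids": "ids", "past": cache, "use_cache": True}              # pre-refactor key
new_inputs = {"input_ids": "ids", "past_key_values": cache, "use_cache": True}   # post-refactor key

print("past" in old_inputs, "past_key_values" in new_inputs)  # True True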
@@ -692,52 +692,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
use_cache=None,
**kwargs,
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
if past is None or len(past) not in {1, 2}:
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
if len(past) == 1:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
if len(past) != 2:
raise ValueError(
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
)
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
if not isinstance(encoder_outputs[0], tf.Tensor):
raise ValueError(
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
if not past_key_values:
raise ValueError(
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
)
decoder_input_ids = decoder_input_ids[:, -1:]
if not isinstance(encoder_outputs, TFBaseModelOutput):
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@@ -750,9 +719,4 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
if len(past) == 1:
return past
encoder_outputs, past_key_values = past
return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
return self.decoder._reorder_cache(past, beam_idx)
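The composite encoder-decoder model now defers cache handling to its decoder. A toy sketch of that delegation; `ToyDecoder` is a stand-in for a decoder exposing the decoder-only contract shown earlier, not a transformers class:

import tensorflow as tf

class ToyDecoder:
    # stands in for a decoder-only LM head model
    def prepare_inputs_for_generation(self, input_ids, past=None):
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "attention_mask": None, "past_key_values": past}

def composite_prepare_inputs(decoder, input_ids, past=None, encoder_outputs=None, use_cache=None):
    decoder_inputs = decoder.prepare_inputs_for_generation(input_ids, past=past)
    return {
        "input_ids": None,
        "encoder_outputs": encoder_outputs,
        "decoder_input_ids": decoder_inputs["input_ids"],
        "past_key_values": decoder_inputs["past_key_values"],
        "use_cache": use_cache,
    }

ids = tf.constant([[0, 5, 9]])
print(composite_prepare_inputs(ToyDecoder(), ids, past=("cache",))["decoder_input_ids"].shape)  # (1, 1)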
@@ -851,12 +851,15 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
# tests will need to be fixed after the change
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@ -17,7 +17,7 @@
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -2097,7 +2097,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
all_self_attns = all_self_attns if inputs["output_attentions"] else None
|
||||
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
|
||||
|
||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||
present_key_values = present_key_values if inputs["use_cache"] else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return tuple(
|
||||
@ -2527,45 +2527,26 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs,
|
||||
TFLEDEncoderBaseModelOutput,
|
||||
), f"encoder_outputs should be a TFLEDEncoderBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
@ -2574,18 +2555,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
||||
def hf_compute_loss(self, labels, logits):
|
||||
"""CrossEntropyLoss that ignores pad tokens"""
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -1050,9 +1050,6 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@ -1477,43 +1474,23 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@ -1528,18 +1505,13 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
||||
def adjust_logits_during_generation(
|
||||
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -1034,9 +1034,6 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@ -1462,43 +1459,23 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@ -1513,15 +1490,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -1058,9 +1058,6 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
@ -1485,43 +1482,23 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past,
|
||||
attention_mask,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
past_key_values = None
|
||||
else:
|
||||
assert (
|
||||
len(past) == 2
|
||||
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
|
||||
encoder_outputs, past_key_values = past
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(
|
||||
encoder_outputs[0], tf.Tensor
|
||||
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
assert (
|
||||
past_key_values
|
||||
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
@ -1536,15 +1513,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past_key_values in past_key_values:
|
||||
for layer_past in past:
|
||||
# cached cross_attention states don't have to be reordered -> they are always the same
|
||||
reordered_past += (
|
||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||
+ layer_past_key_values[2:],
|
||||
tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
@ -16,14 +16,13 @@
|
||||
"""TFRAG model implementation."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput
|
||||
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list
|
||||
from ...utils import logging
|
||||
from .configuration_rag import RagConfig
|
||||
@ -788,42 +787,28 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py
|
||||
def prepare_inputs_for_generation(
|
||||
self, decoder_input_ids, past, attention_mask, use_cache, doc_scores, n_docs=None, **kwargs
|
||||
) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
|
||||
if len(past) == 1:
|
||||
assert isinstance(past[0], tf.Tensor)
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
|
||||
decoder_cached_states = None
|
||||
else:
|
||||
assert len(past) == 2
|
||||
# Note: encoder_outputs is never changed by Bart as a generator
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
|
||||
if isinstance(encoder_outputs, tuple):
|
||||
assert isinstance(encoder_outputs[0], tf.Tensor)
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
|
||||
elif isinstance(encoder_outputs, tf.Tensor):
|
||||
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
|
||||
|
||||
assert (
|
||||
decoder_cached_states
|
||||
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
|
||||
# if past is defined cut decoder_input_ids to last token
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
doc_scores=None,
|
||||
n_docs=None,
|
||||
**kwargs
|
||||
):
|
||||
if past is not None:
|
||||
# if past is defined use only last decoder_input_ids
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"input_ids": None,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"doc_scores": doc_scores,
|
||||
"context_attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"past_key_values": decoder_cached_states,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
"do_marginalize": True,
|
||||
"n_docs": n_docs,
|
||||
}
|
||||
@ -844,46 +829,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
|
||||
|
||||
def tf_index_select(input_, dim, indices):
|
||||
"""
|
||||
Input:
|
||||
input_(tensor): input tensor dim(int): dimension indices(list): selected indices list
|
||||
Output:
|
||||
mimic of torch_tensor.index_select(dim, indices)
|
||||
|
||||
credit:
|
||||
https://stackoverflow.com/questions/58464790/is-there-an-equivalent-function-of-pytorch-named-index-select-in-tensorflow
|
||||
"""
|
||||
shape = shape_list(input_)
|
||||
if dim == -1:
|
||||
dim = len(shape) - 1
|
||||
shape[dim] = 1
|
||||
|
||||
tmp = []
|
||||
for idx in indices:
|
||||
begin = [0] * len(shape)
|
||||
begin[dim] = idx
|
||||
tmp.append(tf.slice(input_, begin, shape))
|
||||
res = tf.concat(tmp, axis=dim)
|
||||
|
||||
return res
|
||||
|
||||
def _reorder_stacked(hidden_states, new_order=beam_idx):
|
||||
def _reorder_stacked(hidden_states, new_order):
|
||||
n_docs = hidden_states.shape[0] // new_order.shape[0]
|
||||
hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
|
||||
hidden_states = tf_index_select(hidden_states, 0, new_order)
|
||||
return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
|
||||
|
||||
if len(past) == 1:
|
||||
return past
|
||||
|
||||
past_key_values = past[1]
|
||||
hidden_states = tf.gather(hidden_states, new_order, axis=0)
|
||||
result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
|
||||
return result
|
||||
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
|
||||
|
||||
return (past[0], reordered_past)
|
||||
return reordered_past
|
||||
|
||||
def marginalize(self, seq_logits, doc_scores, n_docs=None):
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
@ -1268,14 +1226,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
# TODO(Patrick): `encoder_outputs`, `past` hack.
|
||||
# Remove after cleaning encoder-decoder outputs
|
||||
if output_attentions:
|
||||
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
|
||||
if output_hidden_states:
|
||||
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
|
||||
|
||||
decoder_input_ids = tf.fill(
|
||||
(batch_size * num_beams, 1),
|
||||
tf.cast(decoder_start_token_id, tf.int32),
|
||||
@ -1366,10 +1316,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
||||
model_kwargs.pop("output_attentions", None)
|
||||
model_kwargs.pop("output_scores", None)
|
||||
|
||||
# TODO(Patrick): `encoder_outputs`, `past` hack.
|
||||
# Remove after cleaning encoder-decoder outputs
|
||||
model_kwargs["past"] = encoder_outputs
|
||||
|
||||
return self.greedy_search(
|
||||
input_ids=decoder_input_ids,
|
||||
max_length=max_length,
|
||||
|
@ -1176,17 +1176,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
||||
return self.mlm.predictions
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(input_shape)
|
||||
|
||||
return {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": model_kwargs["use_cache"],
|
||||
}
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
@ -1309,6 +1309,14 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
|
||||
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
@ -1209,17 +1209,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
|
||||
return self.name + "/" + self.lm_head.name
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(input_shape)
|
||||
|
||||
return {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": model_kwargs["use_cache"],
|
||||
}
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@ -1344,6 +1344,14 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
|
||||
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
|
||||
class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
@ -1139,7 +1139,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = (inputs["encoder_hidden_states"], next_decoder_cache) if use_cache else None
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns
|
||||
@ -1571,26 +1571,17 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):
if past is not None and len(past) <= 2:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
if len(past) == 1:
past_key_values = None
else:
past_key_values = past[1]
if not past_key_values:
raise ValueError(f"decoder cached states must be truthy, got {past_key_values}")
decoder_input_ids = decoder_input_ids[:, -1:]
else:
raise ValueError(f"`past` must be an iterable with length 1 or 2, got {past}")
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_features": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -1601,15 +1592,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus

@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past

past_key_values = past[1]

reordered_past = ()
for layer_past_key_values in past_key_values:
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
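
The simplified `_reorder_cache` above reduces to gathering the beam dimension of every cached tensor with the indices of the surviving beams. A rough toy illustration (made-up shapes, not the real cache):

import tensorflow as tf

num_beams = 3
# toy cache: 2 layers, each holding 4 states (self-attn key/value, cross-attn key/value)
past = tuple(
    tuple(tf.random.normal((num_beams, 2, 5, 8)) for _ in range(4)) for _ in range(2)
)
beam_idx = tf.constant([2, 0, 0])  # beams kept at this decoding step

reordered = tuple(
    tuple(tf.gather(state, beam_idx, axis=0) for state in layer) for layer in past
)
# shapes are preserved; only the order along the beam dimension changes
assert reordered[0][0].shape == past[0][0].shape
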
@ -1256,15 +1256,13 @@ class TFT5Model(TFT5PreTrainedModel):
return_dict=inputs["return_dict"],
training=inputs["training"],
)
past = decoder_outputs[1] if inputs["use_cache"] else None

if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + inputs["encoder_outputs"]

past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None

return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=past,
@ -1483,8 +1481,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling

loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)

past = decoder_outputs[1] if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
@ -1509,8 +1507,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attentions=attentions,
)

past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None

return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
@ -1544,65 +1540,57 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling

def prepare_inputs_for_generation(
self,
inputs,
past,
attention_mask,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
use_cache=None,
**kwargs,
encoder_outputs=None,
**kwargs
):
assert past is not None, "past has to be defined for encoder_outputs"

# first step
if len(past) < 2:
encoder_outputs, past_key_values = past, None
else:
encoder_outputs, past_key_values = past[0], past[1]
if "encoder_hidden_states" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
if "encoder_attentions" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])

# cut decoder_input_ids if past is used
if past_key_values is not None:
inputs = inputs[:, -1:]
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
"past_key_values": past_key_values,
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": input_ids,
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"use_cache": use_cache,
}

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)

def _reorder_cache(self, past, beam_idx) -> Tuple:
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder

if len(past) < 2:
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past

decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()

for layer_past_states in decoder_past:
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)
reordered_layer_past_states = reordered_layer_past_states + (
tf.gather(layer_past_state, beam_idx, axis=0),
)

assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)

reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)
return reordered_decoder_past


@add_start_docstrings(
@ -1058,15 +1058,22 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
attentions=attns,
)

def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs = {"input_ids": inputs}
def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
inputs = {}

# if past is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
inputs["input_ids"] = tf.expand_dims(input_ids[:, -1], axis=-1)
else:
inputs["input_ids"] = input_ids

return inputs

@staticmethod
def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]


@add_start_docstrings(
"""
@ -722,45 +722,22 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
cross_attentions=cross_attns,
)

def prepare_inputs_for_generation(self, decoder_input_ids, past, use_cache=None, **kwargs):
if past is None or len(past) not in {1, 2}:
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")

if len(past) == 1:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
if len(past) != 2:
raise ValueError(
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
)
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
if not isinstance(encoder_outputs[0], tf.Tensor):
raise ValueError(
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
if not past_key_values:
raise ValueError(
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
)
decoder_input_ids = decoder_input_ids[:, -1:]

if not isinstance(encoder_outputs, TFBaseModelOutput):
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")

return {
"pixel_values": None, # encoder_outputs is defined. pixel_values not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
# TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
"encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@ -773,9 +750,4 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos

def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
if len(past) == 1:
return past

encoder_outputs, past_key_values = past

return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
return self.decoder._reorder_cache(past, beam_idx)
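
The wrapper now delegates input preparation to its decoder instead of unpacking a combined `past`. A bare-bones sketch of that delegation pattern (hypothetical toy classes, not the library implementation):

class ToyDecoder:
    def prepare_inputs_for_generation(self, input_ids, past=None):
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past}

class ToyEncoderDecoder:
    def __init__(self):
        self.decoder = ToyDecoder()

    def prepare_inputs_for_generation(self, input_ids, past=None, encoder_outputs=None, **kwargs):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        return {
            "decoder_input_ids": decoder_inputs["input_ids"],
            "past_key_values": decoder_inputs["past_key_values"],
            "encoder_outputs": encoder_outputs,  # kept as its own kwarg, no longer packed into `past`
        }

Cache reordering follows the same idea: `past` is forwarded to `self.decoder._reorder_cache` untouched, since the encoder outputs no longer travel inside it.
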
@ -1246,17 +1246,17 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_loss.name

def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)

effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)

# At every pass, the attention values for the new token and the two last generated tokens
# are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
# offset = 1; offset = 2 seems to have slightly better computation.
offset = 2

effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)

if past:
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
else:
@ -1277,7 +1277,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
"input_ids": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_mems": kwargs.get("use_mems"),
"use_mems": use_mems,
}

# if past is defined in model kwargs then use it for faster decoding
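
For context on the XLNet-specific inputs built above, a hedged sketch of the dummy-token trick (toy sizes, illustrative only): a placeholder token is appended, the permutation mask hides it from every other position, and the target mapping asks the model to predict exactly that position.

import tensorflow as tf

batch_size, seq_len = 1, 4
inputs = tf.zeros((batch_size, seq_len), dtype=tf.int32)      # toy token ids
dummy_token = tf.zeros((batch_size, 1), dtype=inputs.dtype)
inputs = tf.concat([inputs, dummy_token], axis=1)
new_len = inputs.shape[1]

# no position may attend to the appended dummy token (last column set to 1)
perm_mask = tf.concat(
    [tf.zeros((batch_size, new_len, new_len - 1)), tf.ones((batch_size, new_len, 1))], axis=-1
)
# the model is asked to predict only that last position
target_mapping = tf.concat(
    [tf.zeros((batch_size, 1, new_len - 1)), tf.ones((batch_size, 1, 1))], axis=-1
)
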
@ -1777,7 +1777,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte

{% else %}
import random
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import tensorflow as tf

@ -2736,9 +2736,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)

if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)

if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
else:
@ -3186,43 +3183,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=False,
use_cache=None,
encoder_outputs=None,
**kwargs
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
assert (
len(past) == 2
), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
assert isinstance(
encoder_outputs[0], tf.Tensor
), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
past_key_values
), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"input_ids": None, # needs to be passed to make Keras.layer.__call__ happy
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
@ -3233,17 +3210,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec

@staticmethod
def _reorder_cache(past, beam_idx):
if len(past) == 1:
return past

past_key_values = past[1]

reordered_past = ()
for layer_past_key_values in past_key_values:
reordered_past += (
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
)
return (past[0], reordered_past)
for layer_past in past:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past

def hf_compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
@ -802,7 +802,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -116,7 +116,6 @@ class TFBartModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -114,7 +114,6 @@ class TFBlenderbotModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -114,7 +114,6 @@ class TFBlenderbotSmallModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -133,7 +133,6 @@ class TFLEDModelTester:
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -116,7 +116,6 @@ class TFMarianModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -114,7 +114,6 @@ class TFPegasusModelTester:
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]

# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
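
All of the tester updates above drop the extra unpacking step: `outputs.to_tuple()` now yields the decoder cache directly, one entry per decoder layer with the cached key/value states inside (four per layer for the seq2seq models, as the T5 tester below asserts). A toy illustration of the structure being checked (hypothetical values):

num_layers = 2  # toy value standing in for config.num_layers
decoder_past = tuple(
    ("self_k", "self_v", "cross_k", "cross_v") for _ in range(num_layers)
)
assert len(decoder_past) == num_layers                  # one cache entry per decoder layer
assert all(len(layer) == 4 for layer in decoder_past)   # four cached states per layer
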
@ -182,7 +182,7 @@ class TFSpeech2TextModelTester:
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

_, (_, past_key_values) = outputs.to_tuple()
_, past_key_values = outputs.to_tuple()

# create hypothetical multiple next token and extent to next_input_ids
next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2)
@ -98,13 +98,10 @@ class TFT5ModelTester:
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
self.parent.assertEqual(len(decoder_past[0]), 4)

def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5ForConditionalGeneration(config=config)