Generate: delete unused TF _reorder_cache (#20964)

Joao Gante 2023-01-03 10:54:56 +00:00 committed by GitHub
parent a3e8d3cb1c
commit 4fd89e4978
23 changed files with 0 additions and 207 deletions
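
For context: during beam search, each decoding step can reshuffle which hypotheses survive, so any cached key/value tensors must be re-indexed along the batch-of-beams axis to follow the surviving beams. That is all `_reorder_cache` did. These per-model copies are dead code, presumably because the rewritten TF generation loop handles beam reordering generically and XLA-compatibly (note `_gather_beams`, kept as context in the RAG hunk below). A minimal sketch of the pattern being deleted; the shapes and index values here are illustrative assumptions, not from this commit:

import tensorflow as tf

# Hypothetical cache for one decoder layer: key and value tensors of shape
# (batch * num_beams, heads, seq_len, head_dim).
past = ((tf.random.normal((4, 2, 5, 8)), tf.random.normal((4, 2, 5, 8))),)
# After re-ranking, beams 0 and 1 both continue from old beam 0.
beam_idx = tf.constant([0, 0, 2, 3])

# Re-index every cached state so it follows its surviving beam.
reordered = tuple(
    tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past)
    for layer_past in past
)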

src/transformers/generation/tf_utils.py

@@ -449,10 +449,6 @@ class TFGenerationMixin:
     supports_xla_generation = True

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
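
Note the `axis=1` in this generic fallback: it assumed caches whose batch dimension sits second, e.g. stacked per-layer tensors shaped roughly `(2, batch, heads, seq, dim)`, whereas most per-model overrides below gather on `axis=0`. A hedged illustration with assumed shapes:

import tensorflow as tf

# Hypothetical stacked cache for one layer: key and value stacked on axis 0,
# so the batch-of-beams dimension is axis 1.
layer_past = tf.zeros((2, 4, 2, 5, 8))
beam_idx = tf.constant([0, 2, 2, 3])

reordered = tf.gather(layer_past, beam_idx, axis=1)  # reorder along the beam axis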

src/transformers/models/bart/modeling_tf_bart.py

@@ -1475,16 +1475,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-

 @add_start_docstrings(
     """

src/transformers/models/bert/modeling_tf_bert.py

@@ -1508,13 +1508,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-

 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",

src/transformers/models/blenderbot/modeling_tf_blenderbot.py

@@ -1473,14 +1473,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss):
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

@@ -1453,14 +1453,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel, TFCausalLanguageModelingLoss):
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past

src/transformers/models/camembert/modeling_tf_camembert.py

@@ -1726,11 +1726,3 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFCausalLMOutputWithCrossAttentions(
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
-
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past

src/transformers/models/ctrl/modeling_tf_ctrl.py

@@ -722,12 +722,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)

-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
-        return tuple(
-            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
-        )
-

 @add_start_docstrings(
     """

src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -720,7 +720,3 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
             " model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)

src/transformers/models/led/modeling_tf_led.py

@@ -2538,16 +2538,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(

src/transformers/models/marian/modeling_tf_marian.py

@@ -1494,17 +1494,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):

src/transformers/models/mbart/modeling_tf_mbart.py

@@ -1490,14 +1490,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past

src/transformers/models/pegasus/modeling_tf_pegasus.py

@@ -1503,14 +1503,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past

src/transformers/models/rag/modeling_tf_rag.py

@@ -799,24 +799,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
     def question_encoder(self):
         return self.rag.question_encoder

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-
-        def _reorder_stacked(hidden_states, new_order):
-            n_docs = hidden_states.shape[0] // new_order.shape[0]
-            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf.gather(hidden_states, new_order, axis=0)
-            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
-            return result
-
-        reordered_past = ()
-        for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-        return reordered_past
-
     @staticmethod
     def _gather_beams(nested, beam_indices, batch_axis=0):
         """

src/transformers/models/rembert/modeling_tf_rembert.py

@@ -1244,14 +1244,6 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-

 @add_start_docstrings(
     """

src/transformers/models/roberta/modeling_tf_roberta.py

@@ -1286,14 +1286,6 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-

 class TFRobertaClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""

src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py

@@ -1301,14 +1301,6 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-

 # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
 class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):

src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py

@@ -1501,10 +1501,3 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCausalLanguageModelingLoss):
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past

src/transformers/models/t5/modeling_tf_t5.py

@@ -1528,30 +1528,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)

-    def _reorder_cache(self, past, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
-
-        reordered_decoder_past = ()
-        for layer_past_states in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    tf.gather(layer_past_state, beam_idx, axis=0),
-                )
-
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-

 @add_start_docstrings(
     "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",

src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py

@@ -1039,10 +1039,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
         return inputs

-    @staticmethod
-    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
-        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
-

 @add_start_docstrings(
     """

src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py

@@ -756,7 +756,3 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
             "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)

src/transformers/models/whisper/modeling_tf_whisper.py

@@ -1386,11 +1386,3 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_position_ids": decoder_position_ids,
         }
-
-    #
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),)
-        return reordered_past

src/transformers/models/xglm/modeling_tf_xglm.py

@@ -992,10 +992,3 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
             attentions=attns,
             cross_attentions=cross_attns,
         )
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past

templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py

@@ -3028,13 +3028,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(