From 4fd89e49788f60b021b4e2c578a1fb12f1e900e4 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 3 Jan 2023 10:54:56 +0000
Subject: [PATCH] Generate: delete unused TF `_reorder_cache` (#20964)

---
 src/transformers/generation/tf_utils.py       |  4 ----
 .../models/bart/modeling_tf_bart.py           | 10 --------
 .../models/bert/modeling_tf_bert.py           |  7 ------
 .../blenderbot/modeling_tf_blenderbot.py      | 11 ---------
 .../modeling_tf_blenderbot_small.py           | 11 ---------
 .../models/camembert/modeling_tf_camembert.py |  8 -------
 .../models/ctrl/modeling_tf_ctrl.py           |  6 -----
 .../modeling_tf_encoder_decoder.py            |  4 ----
 .../models/led/modeling_tf_led.py             | 10 --------
 .../models/marian/modeling_tf_marian.py       | 11 ---------
 .../models/mbart/modeling_tf_mbart.py         | 11 ---------
 .../models/pegasus/modeling_tf_pegasus.py     | 11 ---------
 .../models/rag/modeling_tf_rag.py             | 18 --------------
 .../models/rembert/modeling_tf_rembert.py     |  8 -------
 .../models/roberta/modeling_tf_roberta.py     |  8 -------
 .../modeling_tf_roberta_prelayernorm.py       |  8 -------
 .../modeling_tf_speech_to_text.py             |  7 ------
 src/transformers/models/t5/modeling_tf_t5.py  | 24 ------------------
 .../transfo_xl/modeling_tf_transfo_xl.py      |  4 ----
 .../modeling_tf_vision_encoder_decoder.py     |  4 ----
 .../models/whisper/modeling_tf_whisper.py     |  8 -------
 .../models/xglm/modeling_tf_xglm.py           |  7 ------
 ...tf_{{cookiecutter.lowercase_modelname}}.py |  7 ------
 23 files changed, 207 deletions(-)

diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 0198eefa6dc..64c47a1ea62 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -449,10 +449,6 @@ class TFGenerationMixin:
 
     supports_xla_generation = True
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py
index 19204f7e863..6664c65b86a 100644
--- a/src/transformers/models/bart/modeling_tf_bart.py
+++ b/src/transformers/models/bart/modeling_tf_bart.py
@@ -1475,16 +1475,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py
index 3ddc36fc7a5..67714a92c9f 100644
--- a/src/transformers/models/bert/modeling_tf_bert.py
+++ b/src/transformers/models/bert/modeling_tf_bert.py
@@ -1508,13 +1508,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",
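[Note] Every method deleted in this patch was dead code: the rewritten TF
`generate()` keeps all beams stacked along the batch axis and gathers them
itself (see `_gather_beams` in the RAG hunk further down), so these per-model
hooks were never invoked. What each deleted variant boiled down to is a single
gather over the batch axis. A rough sketch with toy shapes, not the library's
actual call site:

    import tensorflow as tf

    # Toy cached self-attention state: (batch * num_beams, heads, seq_len, head_dim)
    past_state = tf.random.normal((4, 2, 5, 8))  # 2 inputs x 2 beams
    beam_idx = tf.constant([1, 0, 3, 2])         # beam each row should continue from

    # Pick rows along the batch axis; shape is unchanged
    reordered = tf.gather(past_state, beam_idx, axis=0)
    assert reordered.shape == past_state.shape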
diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
index 0e9d39d7da8..9890efd310d 100644
--- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -1473,14 +1473,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
index 6e7a6dd2706..293ff0ccf4f 100644
--- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -1453,14 +1453,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
diff --git a/src/transformers/models/camembert/modeling_tf_camembert.py b/src/transformers/models/camembert/modeling_tf_camembert.py
index f75b97ed473..0c904511b81 100644
--- a/src/transformers/models/camembert/modeling_tf_camembert.py
+++ b/src/transformers/models/camembert/modeling_tf_camembert.py
@@ -1726,11 +1726,3 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
         return TFCausalLMOutputWithCrossAttentions(
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
-
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py
index 81872b9671c..8b8c2491aa0 100644
--- a/src/transformers/models/ctrl/modeling_tf_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -722,12 +722,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
 
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
 
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
-        return tuple(
-            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
-        )
-
 
 @add_start_docstrings(
     """
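[Note] The CTRL copy above is the only one that spelled out the cache's type,
`Tuple[Tuple[tf.Tensor]]`: one inner tuple per decoder layer, each holding that
layer's cached key/value tensors. A self-contained sketch of that structure and
of the generic nested reorder, with invented toy sizes:

    import tensorflow as tf

    # One inner tuple per layer, each with cached (key, value) tensors
    past = tuple(
        (tf.random.normal((4, 2, 5, 8)), tf.random.normal((4, 2, 5, 8)))
        for _ in range(3)  # 3 toy decoder layers
    )
    beam_idx = tf.constant([1, 0, 3, 2])

    # Gather every tensor of every layer along the batch axis
    reordered = tuple(
        tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past)
        for layer_past in past
    )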
diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
index a13d18806e5..c6a8fb0f35c 100644
--- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -720,7 +720,3 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
             " model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index 78e98562ac4..d728b283fa3 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -2538,16 +2538,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index a3c8fe62fb0..960ed5aba9b 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -1494,17 +1494,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index 05e7a230762..4c60b42cb7e 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -1490,14 +1490,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
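[Note] The BART-family copies (bart, blenderbot, blenderbot_small, led, marian,
mbart, pegasus) gathered only `layer_past[:2]`, the decoder self-attention
key/value; `layer_past[2:]` holds cross-attention states computed once from the
encoder output, identical across beams of the same input, so reordering them
would be a no-op. A toy illustration (assumed shapes, not the models' real
ones):

    import tensorflow as tf

    beam_idx = tf.constant([1, 0])
    self_k = tf.random.normal((2, 4, 6))  # per-beam decoder keys: must follow the beam
    self_v = tf.random.normal((2, 4, 6))
    enc = tf.random.normal((1, 3, 6))
    cross_k = tf.repeat(enc, 2, axis=0)   # identical rows for both beams
    cross_v = tf.repeat(enc, 2, axis=0)

    layer_past = (self_k, self_v, cross_k, cross_v)
    # Gather only the self-attention half; the cross half is beam-invariant
    reordered = tuple(tf.gather(t, beam_idx, axis=0) for t in layer_past[:2]) + layer_past[2:]

    # Gathering the cross states would have changed nothing:
    tf.debugging.assert_near(tf.gather(cross_k, beam_idx, axis=0), cross_k)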
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index 3d40ae661f6..17d5f2e39b0 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -1503,14 +1503,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py
index c336cab4686..feed5f00d7f 100644
--- a/src/transformers/models/rag/modeling_tf_rag.py
+++ b/src/transformers/models/rag/modeling_tf_rag.py
@@ -799,24 +799,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     def question_encoder(self):
         return self.rag.question_encoder
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-
-        def _reorder_stacked(hidden_states, new_order):
-            n_docs = hidden_states.shape[0] // new_order.shape[0]
-            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf.gather(hidden_states, new_order, axis=0)
-            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
-            return result
-
-        reordered_past = ()
-        for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-
-        return reordered_past
-
     @staticmethod
     def _gather_beams(nested, beam_indices, batch_axis=0):
         """
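[Note] RAG's variant differed because its cache batch axis is really
`batch * n_docs`: `_reorder_stacked` unflattens that axis, gathers whole
per-beam blocks of documents, then flattens back. A small numeric sketch of the
same reshape-gather-reshape, with made-up sizes:

    import tensorflow as tf

    n_beams, n_docs, d = 2, 3, 4
    # Rows are laid out beam-major: beam0/doc0, beam0/doc1, ..., beam1/doc2
    hidden = tf.reshape(tf.range(n_beams * n_docs * d, dtype=tf.float32), (n_beams * n_docs, d))
    new_order = tf.constant([1, 0])  # swap the two beams

    n = hidden.shape[0] // new_order.shape[0]       # docs per beam
    grouped = tf.reshape(hidden, (-1, n, d))        # (beams, docs, d)
    picked = tf.gather(grouped, new_order, axis=0)  # doc blocks travel with their beam
    result = tf.reshape(picked, (-1, d))            # back to (beams * docs, d)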
diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py
index 3ca26dd670b..26ca303471c 100644
--- a/src/transformers/models/rembert/modeling_tf_rembert.py
+++ b/src/transformers/models/rembert/modeling_tf_rembert.py
@@ -1244,14 +1244,6 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py
index 9488412699d..2d4ae024631 100644
--- a/src/transformers/models/roberta/modeling_tf_roberta.py
+++ b/src/transformers/models/roberta/modeling_tf_roberta.py
@@ -1286,14 +1286,6 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 class TFRobertaClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py
index 175573626cb..02f3385ef9a 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py
@@ -1301,14 +1301,6 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
 class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
index 377c862a837..c1dbe2e7c6f 100755
--- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -1501,10 +1501,3 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index ff20ff1d17b..15ece6e2397 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -1528,30 +1528,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)
 
-    def _reorder_cache(self, past, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
-
-        reordered_decoder_past = ()
-        for layer_past_states in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    tf.gather(layer_past_state, beam_idx, axis=0),
-                )
-
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-
 
 @add_start_docstrings(
     "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
index c3da50d8209..ee2a2a1bb2a 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -1039,10 +1039,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
 
         return inputs
 
-    @staticmethod
-    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
-        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
-
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
index 4133dcba67c..16b9d77f953 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
@@ -756,7 +756,3 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
             "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py
index 3753099d311..e8334c49e6e 100644
--- a/src/transformers/models/whisper/modeling_tf_whisper.py
+++ b/src/transformers/models/whisper/modeling_tf_whisper.py
@@ -1386,11 +1386,3 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_position_ids": decoder_position_ids,
         }
-
-    #
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),)
-        return reordered_past
diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py
index 95e9d228402..655c2d45f24 100644
--- a/src/transformers/models/xglm/modeling_tf_xglm.py
+++ b/src/transformers/models/xglm/modeling_tf_xglm.py
@@ -992,10 +992,3 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
             attentions=attns,
             cross_attentions=cross_attns,
         )
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
index 95770c82697..1fbc4b6eb59 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -3028,13 +3028,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
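[Note] The two `axis=1` deletions (the `TFGenerationMixin` fallback and
Transfo-XL's `mems`) were not a different algorithm, only a different layout:
those caches keep the batch dimension second. Transfo-XL memories are
time-major, `(mem_len, batch, d_model)`; whether the mixin's stacked `past`
format carried batch on the same axis is an assumption here. A sketch
mirroring the deleted Transfo-XL code, with toy sizes:

    import tensorflow as tf

    # Transfo-XL-style memories: (mem_len, batch, d_model), batch on axis 1
    mems = [tf.random.normal((5, 4, 16)) for _ in range(2)]
    beam_idx = tf.constant([2, 2, 0, 1])

    # Same gather as before, just aimed at the second axis
    reordered = [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
    assert reordered[0].shape == mems[0].shape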