Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Generate: delete unused TF _reorder_cache (#20964)

Parent: a3e8d3cb1c
Commit: 4fd89e4978
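Background (editor's note, not part of the original commit message): during beam search the beams are re-ranked at every step, so cached key/value states have to be permuted to follow the surviving hypotheses. The `_reorder_cache` helpers deleted below did exactly that with `tf.gather`; they went unused once the rewritten TF generation code started gathering the cache through its own path (see `_gather_beams`, kept as context in the RAG hunk below). A minimal sketch of the operation, with toy shapes rather than the library's real ones:

    import tensorflow as tf

    # One layer's cached keys: (batch * num_beams, heads, seq_len, head_dim)
    past_key = tf.reshape(tf.range(4 * 2 * 3 * 5, dtype=tf.float32), (4, 2, 3, 5))

    # After a step, beams 0 and 2 each produced both survivors.
    beam_idx = tf.constant([0, 0, 2, 2])

    # Reordering the cache is a single gather along the beam axis.
    reordered = tf.gather(past_key, beam_idx, axis=0)
    assert reordered.shape == past_key.shape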
@@ -449,10 +449,6 @@ class TFGenerationMixin:
 
     supports_xla_generation = True
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
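Note that this base implementation gathered on axis=1, while most of the per-model overrides below gather on axis=0: which axis holds the batch-of-beams depends on each model's cache layout. Illustrative shapes only:

    import tensorflow as tf

    beam_idx = tf.constant([1, 0])

    # Layout with beams on the first axis: (num_beams, heads, seq, dim)
    beams_first = tf.zeros((2, 4, 7, 8))
    reordered_a = tf.gather(beams_first, beam_idx, axis=0)

    # Layout with beams on the second axis, e.g. stacked key/value pairs
    beams_second = tf.zeros((2, 2, 4, 7, 8))
    reordered_b = tf.gather(beams_second, beam_idx, axis=1)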
@@ -1475,16 +1475,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
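The BART-family variant above skips `layer_past[2:]` because cached cross-attention keys/values are computed once from the encoder output and are identical for every beam; only the self-attention entries in `layer_past[:2]` need the gather. A hedged sketch of that split (shapes are made up):

    import tensorflow as tf

    beam_idx = tf.constant([2, 2, 0, 1])
    self_k = self_v = tf.zeros((4, 8, 5, 16))    # per-beam self-attention cache
    cross_k = cross_v = tf.zeros((4, 8, 9, 16))  # encoder-derived, same for all beams

    layer_past = (self_k, self_v, cross_k, cross_v)
    reordered = tuple(tf.gather(s, beam_idx, axis=0) for s in layer_past[:2]) + layer_past[2:]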
@@ -1508,13 +1508,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",
@@ -1473,14 +1473,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1453,14 +1453,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1726,11 +1726,3 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
         return TFCausalLMOutputWithCrossAttentions(
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
-
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
@@ -722,12 +722,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
 
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
 
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
-        return tuple(
-            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
-        )
-
 
 @add_start_docstrings(
     """
@@ -720,7 +720,3 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
             " model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
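Composite wrappers like `TFEncoderDecoderModel` never reordered anything themselves; they forwarded to the wrapped decoder, which owns the cache layout. A sketch of that delegation pattern (the class here is a hypothetical stand-in, not the library's):

    class ComposedModel:
        def __init__(self, decoder):
            self.decoder = decoder

        def _reorder_cache(self, past, beam_idx):
            # The decoder defines the cache layout, so it performs the reordering.
            return self.decoder._reorder_cache(past, beam_idx)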
@@ -2538,16 +2538,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
@@ -1494,17 +1494,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
@@ -1490,14 +1490,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1503,14 +1503,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -799,24 +799,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     def question_encoder(self):
         return self.rag.question_encoder
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-
-        def _reorder_stacked(hidden_states, new_order):
-            n_docs = hidden_states.shape[0] // new_order.shape[0]
-            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf.gather(hidden_states, new_order, axis=0)
-            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
-            return result
-
-        reordered_past = ()
-        for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-
-        return reordered_past
-
     @staticmethod
     def _gather_beams(nested, beam_indices, batch_axis=0):
         """
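The RAG variant above has an extra wrinkle: retrieved documents are stacked into the batch axis, so a cache entry's leading dimension is `num_beams * n_docs` and a plain gather over beams would be wrong. `_reorder_stacked` un-stacks the docs, gathers the beams, and re-stacks. A sketch with explicit toy dimensions:

    import tensorflow as tf

    num_beams, n_docs = 2, 3
    # Cache entry with docs folded into the batch axis: (num_beams * n_docs, seq, dim)
    hidden_states = tf.zeros((num_beams * n_docs, 5, 16))
    beam_idx = tf.constant([1, 1])

    unstacked = tf.reshape(hidden_states, (num_beams, n_docs, 5, 16))
    gathered = tf.gather(unstacked, beam_idx, axis=0)
    restacked = tf.reshape(gathered, (num_beams * n_docs, 5, 16))
    assert restacked.shape == hidden_states.shape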
@@ -1244,14 +1244,6 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """
@@ -1286,14 +1286,6 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 class TFRobertaClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""
@@ -1301,14 +1301,6 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
 class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
@@ -1501,10 +1501,3 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
@@ -1528,30 +1528,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)
 
-    def _reorder_cache(self, past, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
-
-        reordered_decoder_past = ()
-        for layer_past_states in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    tf.gather(layer_past_state, beam_idx, axis=0),
-                )
-
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-
 
 @add_start_docstrings(
     "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
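The T5 variant also guarded against `past=None`, which happens when caching is disabled: decoding still works, it is just slower, so the method warned instead of failing. A condensed sketch of that guard (the function name here is hypothetical):

    import tensorflow as tf

    def reorder_cache_or_warn(past, beam_idx):
        # No cache returned: nothing to reorder, decoding is merely slower.
        if past is None:
            print("You might want to consider setting `use_cache=True` to speed up decoding")
            return past
        return tuple(
            tuple(tf.gather(state, beam_idx, axis=0) for state in layer_past)
            for layer_past in past
        )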
@@ -1039,10 +1039,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
 
         return inputs
 
-    @staticmethod
-    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
-        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
-
 
 @add_start_docstrings(
     """
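Transfo-XL keeps its memories time-major, with the batch dimension second, which is presumably why this override gathers `mems` (a list of tensors, not a tuple of tuples) on axis=1. Illustrative shapes only:

    import tensorflow as tf

    # Time-major memories: (mem_len, batch * num_beams, d_model)
    mems = [tf.zeros((10, 4, 32)) for _ in range(2)]
    beam_idx = tf.constant([0, 0, 3, 2])

    reordered_mems = [tf.gather(m, beam_idx, axis=1) for m in mems]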
@@ -756,7 +756,3 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
             "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
@@ -1386,11 +1386,3 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_position_ids": decoder_position_ids,
         }
-
-    #
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),)
-        return reordered_past
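The Whisper copy above calls `tf.gather(past_state, beam_idx)` with no `axis` argument; `tf.gather` defaults to axis 0, so it behaves the same as the explicit `axis=0` form used elsewhere in this diff:

    import tensorflow as tf

    x = tf.constant([[1.0], [2.0], [3.0]])
    idx = tf.constant([2, 0])
    assert bool(tf.reduce_all(tf.gather(x, idx) == tf.gather(x, idx, axis=0)))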
@@ -992,10 +992,3 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
             attentions=attns,
             cross_attentions=cross_attns,
         )
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
@@ -3028,13 +3028,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(