mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adding prepare_decoder_input_ids_from_labels methods to all ConditionalGeneration TF models (#12560)
This commit is contained in:
parent
ebc69afc30
commit
95425d546d
@ -1494,6 +1494,9 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
|
@ -2522,6 +2522,9 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
if len(past) == 1:
|
||||
|
@ -1522,6 +1522,9 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
@ -1506,6 +1506,9 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
@ -1531,6 +1531,9 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
@ -1499,6 +1499,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return self._shift_right(labels)
|
||||
|
||||
def _reorder_cache(self, past, beam_idx) -> Tuple:
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
|
Loading…
Reference in New Issue
Block a user