From 7d65697da746066aa75238347d8c86bde1acbf1b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 7 Aug 2023 15:38:24 +0100
Subject: [PATCH] Generate: remove Marian hack (#25294)

Remove Marian hack
---
 src/transformers/generation/tf_utils.py      | 21 -------------------
 src/transformers/generation/utils.py         | 18 ----------------
 .../models/marian/modeling_marian.py         |  4 ----
 .../models/marian/modeling_tf_marian.py      | 15 -------------
 4 files changed, 58 deletions(-)

diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index dcb407f5707..648ec710cfe 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -474,27 +474,6 @@ class TFGenerationMixin:
                 "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
             )
 
-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        vocab_size = getattr(self.config, "vocab_size", None)
-        if vocab_size is None and self.config.is_encoder_decoder:
-            decoder_config = getattr(self.config, "decoder", None)
-            if decoder_config is not None:
-                vocab_size = getattr(self.config.decoder, "vocab_size", None)
-
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
-        else:
-            return logits
-
     def compute_transition_scores(
         self,
         sequences: tf.Tensor,
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index aca7bdd6f45..29d636b3c63 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -578,12 +578,6 @@ class GenerationMixin:
         inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
         return inputs, input_name, model_kwargs
 
-    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        return logits
-
     def _maybe_initialize_input_ids_for_generation(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -3060,9 +3054,6 @@ class GenerationMixin:
                 continue  # don't waste resources running the code we don't need
 
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -3388,9 +3379,6 @@ class GenerationMixin:
 
             next_token_logits = outputs.logits[:, -1, :]
 
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -3751,9 +3739,6 @@ class GenerationMixin:
 
                 # select outputs of beams of current group only
                 next_token_logits = outputs.logits[batch_group_indices, -1, :]
-                # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-                # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-                next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
                 next_token_scores = nn.functional.log_softmax(
                     next_token_logits, dim=-1
                 )  # (batch_size * group_size, vocab_size)
@@ -4110,9 +4095,6 @@ class GenerationMixin:
                 continue  # don't waste resources running the code we don't need
 
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 55f8127dc9d..bad7ff71944 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(self, logits, cur_len):
-        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
-        return logits
-
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index 8e0cecb99b5..f163c821713 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """Never predict pad_token_id. Predict </s> when max_length is reached."""
-        vocab_range = tf.constant(range(self.config.vocab_size))
-        logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
-        else:
-            return logits
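
Note on the PyTorch side: the deleted `adjust_logits_during_generation` hook edited the logits immediately before `nn.functional.log_softmax`. Masking a logit to -inf at that point guarantees the token also has -inf log-probability after the (renormalizing) log-softmax, which is what the removed comment's "before and after" refers to. The same pad suppression can be written against the public logits-processor API instead of a model-level override. A minimal sketch, assuming a Marian checkpoint; the `SuppressPadLogitsProcessor` name and the checkpoint are illustrative, not part of this patch:

    import torch
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        LogitsProcessor,
        LogitsProcessorList,
    )

    class SuppressPadLogitsProcessor(LogitsProcessor):
        """Mask pad so it is never generated, mirroring the removed Marian hook."""

        def __init__(self, pad_token_id: int):
            self.pad_token_id = pad_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # same masking the hook applied, now before generate's log_softmax
            scores[:, self.pad_token_id] = float("-inf")
            return scores

    name = "Helsinki-NLP/opus-mt-en-de"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSeq2SeqLM.from_pretrained(name)
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    out = model.generate(
        **inputs,
        num_beams=4,
        logits_processor=LogitsProcessorList(
            [SuppressPadLogitsProcessor(model.config.pad_token_id)]
        ),
    )

Passing `bad_words_ids=[[model.config.pad_token_id]]` to `generate` is a built-in way to get the same effect without a custom class.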
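
Note on the TensorFlow side: the deleted TF helper bundled the pad mask with forced-BOS/EOS logic (`cur_len == 1` forces a BOS token, `cur_len == max_length - 1` forces an EOS token). Both branches map onto standard generation machinery: the `forced_bos_token_id` and `forced_eos_token_id` generation arguments, backed by `ForcedBOSTokenLogitsProcessor` and `ForcedEOSTokenLogitsProcessor`. A minimal sketch of that mapping, with illustrative token ids:

    from transformers import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        LogitsProcessorList,
    )

    max_length = 64  # illustrative values, not from this patch
    forced_bos_id, forced_eos_id = 0, 2

    processors = LogitsProcessorList(
        [
            # replaces the `cur_len == 1` branch of the removed helper
            ForcedBOSTokenLogitsProcessor(bos_token_id=forced_bos_id),
            # replaces the `cur_len == max_length - 1` branch
            ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=forced_eos_id),
        ]
    )
    # equivalently: model.generate(..., forced_bos_token_id=forced_bos_id,
    #                              forced_eos_token_id=forced_eos_id)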