Generate: remove Marian hack (#25294)

Remove Marian hack
This commit is contained in:
Joao Gante 2023-08-07 15:38:24 +01:00 committed by GitHub
parent 145109382a
commit 7d65697da7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 0 additions and 58 deletions

View File

@ -474,27 +474,6 @@ class TFGenerationMixin:
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
def adjust_logits_during_generation(
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
):
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
"""
vocab_size = getattr(self.config, "vocab_size", None)
if vocab_size is None and self.config.is_encoder_decoder:
decoder_config = getattr(self.config, "decoder", None)
if decoder_config is not None:
vocab_size = getattr(self.config.decoder, "vocab_size", None)
if cur_len == 1 and forced_bos_token_id is not None:
vocab_range = tf.constant(range(vocab_size))
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
vocab_range = tf.constant(range(vocab_size))
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
else:
return logits
def compute_transition_scores(
self,
sequences: tf.Tensor,

View File

@ -578,12 +578,6 @@ class GenerationMixin:
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
"""
return logits
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
@ -3060,9 +3054,6 @@ class GenerationMixin:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@ -3388,9 +3379,6 @@ class GenerationMixin:
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@ -3751,9 +3739,6 @@ class GenerationMixin:
# select outputs of beams of current group only
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * group_size, vocab_size)
@ -4110,9 +4095,6 @@ class GenerationMixin:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)

View File

@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len):
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
return logits
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

View File

@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
):
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
vocab_range = tf.constant(range(self.config.vocab_size))
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
if cur_len == 1 and forced_bos_token_id is not None:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits