TF: remove set_tensor_by_indices_to_value (#16729)

This commit is contained in:
Joao Gante 2022-04-12 17:51:47 +01:00 committed by GitHub
parent a315988bae
commit d7f7f29f29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 27 deletions

View File

@ -19,7 +19,6 @@ from typing import List
import numpy as np
import tensorflow as tf
from .tf_utils import set_tensor_by_indices_to_value
from .utils import add_start_docstrings
from .utils.logging import get_logger
@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
# generate is not XLA - compileable anyways
if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
return scores
@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
return scores
@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
return scores

View File

@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .tf_utils import shape_list
from .utils import ModelOutput, logging
@ -952,8 +952,7 @@ class TFGenerationMixin:
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
scores = tf.where(eos_token_indices_mask, -float("inf"), scores)
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
@ -969,8 +968,8 @@ class TFGenerationMixin:
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
scores = tf.where(
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
)
if bad_words_ids is not None:
@ -983,8 +982,8 @@ class TFGenerationMixin:
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
scores = tf.where(
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
logits = tf.where(indices_to_remove, filter_value, logits)
if top_p < 1.0:
sorted_indices = tf.argsort(logits, direction="DESCENDING")
sorted_logits = tf.gather(
@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
)
# scatter sorted tensors to original indexing
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
logits = tf.where(indices_to_remove, filter_value, logits)
return logits

View File

@ -23,11 +23,6 @@ from .utils import logging
logger = logging.get_logger(__name__)
def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]):
# create value_tensor since tensor value assignment is not possible in TF
return tf.where(indices, value, tensor)
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.

View File

@ -37,7 +37,6 @@ if is_tf_available():
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from transformers.tf_utils import set_tensor_by_indices_to_value
from ..test_modeling_tf_common import ids_tensor
@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
scores = tf.where(mask, 4 / vocab_size, scores)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
# scores should be equal
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)