mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: remove set_tensor_by_indices_to_value (#16729)
This commit is contained in:
parent
a315988bae
commit
d7f7f29f29
@ -19,7 +19,6 @@ from typing import List
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .tf_utils import set_tensor_by_indices_to_value
|
||||
from .utils import add_start_docstrings
|
||||
from .utils.logging import get_logger
|
||||
|
||||
@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
|
||||
# generate is not XLA - compileable anyways
|
||||
if cur_len < self.min_length:
|
||||
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
|
||||
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
|
||||
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
||||
|
||||
return scores
|
||||
|
||||
@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
)
|
||||
|
||||
scores = set_tensor_by_indices_to_value(
|
||||
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||
)
|
||||
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
||||
|
||||
return scores
|
||||
|
||||
@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
)
|
||||
|
||||
scores = set_tensor_by_indices_to_value(
|
||||
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||
)
|
||||
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
||||
|
||||
return scores
|
||||
|
||||
|
@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
)
|
||||
from .tf_utils import set_tensor_by_indices_to_value, shape_list
|
||||
from .tf_utils import shape_list
|
||||
from .utils import ModelOutput, logging
|
||||
|
||||
|
||||
@ -952,8 +952,7 @@ class TFGenerationMixin:
|
||||
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||
)
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
|
||||
|
||||
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
|
||||
scores = tf.where(eos_token_indices_mask, -float("inf"), scores)
|
||||
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
@ -969,8 +968,8 @@ class TFGenerationMixin:
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
)
|
||||
|
||||
scores = set_tensor_by_indices_to_value(
|
||||
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||
scores = tf.where(
|
||||
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
|
||||
)
|
||||
|
||||
if bad_words_ids is not None:
|
||||
@ -983,8 +982,8 @@ class TFGenerationMixin:
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
)
|
||||
|
||||
scores = set_tensor_by_indices_to_value(
|
||||
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||
scores = tf.where(
|
||||
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
|
||||
)
|
||||
|
||||
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
|
||||
@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
|
||||
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
|
||||
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
||||
logits = tf.where(indices_to_remove, filter_value, logits)
|
||||
if top_p < 1.0:
|
||||
sorted_indices = tf.argsort(logits, direction="DESCENDING")
|
||||
sorted_logits = tf.gather(
|
||||
@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
|
||||
)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
|
||||
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
||||
logits = tf.where(indices_to_remove, filter_value, logits)
|
||||
return logits
|
||||
|
||||
|
||||
|
@ -23,11 +23,6 @@ from .utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]):
|
||||
# create value_tensor since tensor value assignment is not possible in TF
|
||||
return tf.where(indices, value, tensor)
|
||||
|
||||
|
||||
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
|
||||
"""
|
||||
Deal with dynamic shape in tensorflow cleanly.
|
||||
|
@ -37,7 +37,6 @@ if is_tf_available():
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
)
|
||||
from transformers.tf_utils import set_tensor_by_indices_to_value
|
||||
|
||||
from ..test_modeling_tf_common import ids_tensor
|
||||
|
||||
@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
|
||||
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
|
||||
scores = tf.where(mask, -1 / vocab_size, scores)
|
||||
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
||||
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
|
||||
scores = tf.where(mask, 4 / vocab_size, scores)
|
||||
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
|
||||
@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||
|
||||
# remove inf
|
||||
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
|
||||
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
|
||||
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
||||
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
|
||||
|
Loading…
Reference in New Issue
Block a user