mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: XLA bad words logits processor and list of processors (#16974)
This commit is contained in:
parent
57e6464ac9
commit
fb0ae12947
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -38,7 +38,10 @@ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||
search or log softmax for each vocabulary token when using beam search
|
||||
search or log softmax for each vocabulary token when using beam search.
|
||||
cur_len (`int`):
|
||||
The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
|
||||
is the maximum length generate can produce, and we need to know which of its tokens are valid.
|
||||
kwargs:
|
||||
Additional logits processor specific kwargs.
|
||||
|
||||
@ -51,7 +54,7 @@ class TFLogitsProcessor:
|
||||
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||
|
||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
"""TF method for processing logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
@ -62,7 +65,7 @@ class TFLogitsWarper:
|
||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||
|
||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
"""TF method for warping logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
|
||||
"""
|
||||
|
||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
|
||||
for processor in self:
|
||||
function_args = inspect.signature(processor.__call__).parameters
|
||||
if len(function_args) > 2:
|
||||
if len(function_args) > 3:
|
||||
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
|
||||
raise ValueError(
|
||||
f"Make sure that all the required parameters: {list(function_args.keys())} for "
|
||||
f"{processor.__class__} are passed to the logits processor."
|
||||
)
|
||||
scores = processor(input_ids, scores, **kwargs)
|
||||
scores = processor(input_ids, scores, cur_len, **kwargs)
|
||||
else:
|
||||
scores = processor(input_ids, scores)
|
||||
scores = processor(input_ids, scores, cur_len)
|
||||
return scores
|
||||
|
||||
|
||||
@ -107,7 +110,7 @@ class TFTemperatureLogitsWarper(TFLogitsWarper):
|
||||
|
||||
self.temperature = temperature
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
@ -133,7 +136,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
|
||||
# Boolean mask containing all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
|
||||
@ -163,7 +166,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
|
||||
|
||||
mask_scores = tf.fill(scores.shape, self.filter_value)
|
||||
@ -305,58 +308,75 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
|
||||
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
||||
)
|
||||
|
||||
self.bad_words_ids = bad_words_ids
|
||||
# stores the information about bad words in three tensors:
|
||||
# 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
|
||||
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
|
||||
# 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
|
||||
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
|
||||
if any([word_len == 0 for word_len in bad_word_seqs_len]):
|
||||
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
|
||||
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
|
||||
# 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
|
||||
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
|
||||
|
||||
def calc_banned_bad_words_ids(self, prev_input_ids):
|
||||
banned_tokens = []
|
||||
def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
|
||||
def _tokens_match(bad_word_seq_number):
|
||||
def _len_one():
|
||||
# If the bad sequence only has one token, always mask it
|
||||
return tf.cond(
|
||||
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
|
||||
lambda: tf.ones((), dtype=tf.bool),
|
||||
_len_greater_than_cur_len,
|
||||
)
|
||||
|
||||
def _tokens_match(prev_tokens, tokens):
|
||||
if len(tokens) == 0:
|
||||
# if bad word tokens is just one token always ban it
|
||||
return True
|
||||
if len(tokens) > len(prev_tokens):
|
||||
# if bad word tokens are longer than prev tokens they can't be equal
|
||||
return False
|
||||
def _len_greater_than_cur_len():
|
||||
# Otherwise, if the bad sequence is longer than the current length they can't ever match
|
||||
return tf.cond(
|
||||
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
|
||||
lambda: tf.zeros((), dtype=tf.bool),
|
||||
_match_found,
|
||||
)
|
||||
|
||||
if prev_tokens[-len(tokens) :] == tokens:
|
||||
# if tokens match
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
def _match_found():
|
||||
# Finaly, runs the actual comparison. Can only be called if the previous comparisons do not yield
|
||||
# an answer (otherwise we get indexing exceptions)
|
||||
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
|
||||
return tf.cond(
|
||||
tf.math.reduce_all(
|
||||
tf.math.equal(
|
||||
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
|
||||
)
|
||||
),
|
||||
lambda: tf.ones((), dtype=tf.bool),
|
||||
lambda: tf.zeros((), dtype=tf.bool),
|
||||
)
|
||||
|
||||
for prev_input_ids_slice in prev_input_ids:
|
||||
banned_tokens_slice = []
|
||||
match = _len_one()
|
||||
return match
|
||||
|
||||
for banned_token_seq in self.bad_words_ids:
|
||||
assert (
|
||||
len(banned_token_seq) > 0
|
||||
), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
|
||||
|
||||
if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
|
||||
# if tokens do not match continue
|
||||
continue
|
||||
|
||||
banned_tokens_slice.append(banned_token_seq[-1])
|
||||
|
||||
banned_tokens.append(banned_tokens_slice)
|
||||
|
||||
return banned_tokens
|
||||
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
|
||||
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
|
||||
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
|
||||
return row_banned_tokens
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
|
||||
vocab_size = scores.shape[-1]
|
||||
|
||||
# calculate a list of banned tokens according to bad words
|
||||
banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
|
||||
|
||||
banned_tokens_indices_mask = []
|
||||
for banned_tokens_slice in banned_tokens:
|
||||
banned_tokens_indices_mask.append(
|
||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||
# We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
|
||||
# `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
|
||||
# To remain simple and XLA-compatible, we work on a per-row fashion.
|
||||
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
|
||||
# a frequent choke point. (make `cur_len` a tensor?)
|
||||
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
|
||||
row_input_ids, row_score = row_inputs
|
||||
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
|
||||
banned_tokens_mask = tf.scatter_nd(
|
||||
indices=tf.expand_dims(banned_tokens, axis=-1),
|
||||
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
|
||||
shape=row_score.shape,
|
||||
)
|
||||
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
|
||||
return row_score
|
||||
|
||||
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
||||
|
||||
scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
|
||||
return scores
|
||||
|
||||
|
||||
@ -401,6 +421,11 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
|
||||
# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
|
||||
# https://github.com/huggingface/transformers/pull/16974
|
||||
if not tf.executing_eagerly():
|
||||
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
|
||||
|
||||
batch_size, vocab_size = scores.shape
|
||||
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
|
||||
|
||||
|
@ -2030,7 +2030,7 @@ class TFGenerationMixin:
|
||||
if not use_xla:
|
||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||
input_ids = tf.transpose(input_ids[: current_pos[0]])
|
||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
|
||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
|
||||
|
||||
# argmax
|
||||
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
|
||||
@ -2301,8 +2301,8 @@ class TFGenerationMixin:
|
||||
if not use_xla:
|
||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||
input_ids = tf.transpose(input_ids[:cur_len])
|
||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
||||
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
|
||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
|
||||
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
|
||||
|
||||
# sample
|
||||
if seed is not None:
|
||||
@ -2726,7 +2726,7 @@ class TFGenerationMixin:
|
||||
# add new logprobs to existing running logprobs scores.
|
||||
log_probs = tf.nn.log_softmax(logits)
|
||||
log_probs = logits_processor(
|
||||
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
|
||||
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
|
||||
)
|
||||
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
||||
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
|
||||
|
@ -75,6 +75,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_temperature_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
@ -94,8 +95,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
|
||||
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
|
||||
|
||||
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
|
||||
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
|
||||
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
|
||||
@ -142,6 +143,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_top_k_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
@ -153,7 +155,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
if use_xla:
|
||||
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits)
|
||||
scores = top_k_warp(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
|
||||
@ -167,12 +169,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
if use_xla:
|
||||
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
|
||||
|
||||
scores = top_k_warp_safety_check(input_ids, logits)
|
||||
scores = top_k_warp_safety_check(input_ids, logits, cur_len)
|
||||
# uniform dist is not changed
|
||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
|
||||
|
||||
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
|
||||
@ -180,6 +182,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_top_p_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
@ -189,7 +192,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
top_p_warp = TFTopPLogitsWarper(0.7)
|
||||
if use_xla:
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist))
|
||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# exp (-inf) => 0
|
||||
@ -208,7 +211,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
if use_xla:
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
|
||||
# 2.
|
||||
@ -242,7 +245,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
|
||||
)
|
||||
|
||||
def test_no_bad_words_dist_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_no_bad_words_dist_processor(self, use_xla):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
eos_token_id = 4
|
||||
@ -255,6 +259,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
|
||||
|
||||
@ -322,7 +328,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
def test_processor_list(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_processor_list(self, use_xla):
|
||||
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
|
||||
batch_size = 4
|
||||
cur_len = 10
|
||||
vocab_size = 15
|
||||
@ -341,16 +349,24 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TFTopKLogitsWarper(3)
|
||||
top_p_warp = TFTopPLogitsWarper(0.8)
|
||||
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
|
||||
# no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
|
||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
|
||||
temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
|
||||
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
|
||||
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
# no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
|
||||
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||
|
||||
# no processor list
|
||||
scores = min_dist_proc(input_ids, scores, cur_len)
|
||||
scores = temp_dist_warp(input_ids, scores)
|
||||
scores = temp_dist_warp(input_ids, scores, cur_len)
|
||||
scores = rep_penalty_proc(input_ids, scores, cur_len)
|
||||
scores = top_k_warp(input_ids, scores)
|
||||
scores = top_p_warp(input_ids, scores)
|
||||
scores = no_repeat_proc(input_ids, scores, cur_len)
|
||||
scores = top_k_warp(input_ids, scores, cur_len)
|
||||
scores = top_p_warp(input_ids, scores, cur_len)
|
||||
# scores = no_repeat_proc(input_ids, scores, cur_len)
|
||||
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
|
||||
|
||||
# with processor list
|
||||
@ -361,11 +377,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
rep_penalty_proc,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
no_repeat_proc,
|
||||
# no_repeat_proc,
|
||||
no_bad_words_dist_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len)
|
||||
|
||||
# remove inf
|
||||
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
||||
|
Loading…
Reference in New Issue
Block a user