TF: XLA bad words logits processor and list of processors (#16974)

Joao Gante 2022-04-29 15:54:58 +01:00 committed by GitHub
parent 57e6464ac9
commit fb0ae12947
3 changed files with 115 additions and 74 deletions

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
from typing import List
from typing import List, Tuple
import numpy as np
import tensorflow as tf
@@ -38,7 +38,10 @@ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary token when not using beam
search or log softmax for each vocabulary token when using beam search
search or log softmax for each vocabulary token when using beam search.
cur_len (`int`):
The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
is the maximum length generate can produce, and we need to know which of its tokens are valid.
kwargs:
Additional logits processor specific kwargs.
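A minimal eager sketch of the convention the new `cur_len` argument documents, with made-up values: in the TF implementation, `input_ids` is padded up to the maximum generation length, and `cur_len` marks how many of its tokens are real.

import tensorflow as tf

# Hypothetical shapes: max_length = 8, but only the first 3 tokens are valid.
cur_len = 3
input_ids = tf.constant([[5, 2, 9, 0, 0, 0, 0, 0]])  # padded to max_length

valid_tokens = input_ids[:, :cur_len]  # the prefix a processor should reason about
print(valid_tokens.numpy())  # [[5 2 9]]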
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -62,7 +65,7 @@ class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
"""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if len(function_args) > 3:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores)
scores = processor(input_ids, scores, cur_len)
return scores
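A self-contained sketch of the dispatch above, using toy stand-in classes rather than the transformers API: processors whose `__call__` takes only `(input_ids, scores, cur_len)` are called directly, while those declaring extra parameters receive them via `**kwargs` (here the slice starts at index 3, since the first three parameters are always positional).

import inspect

class Plain:
    def __call__(self, input_ids, scores, cur_len):
        return scores

class NeedsExtra:
    def __call__(self, input_ids, scores, cur_len, boost):
        return scores + boost

def run_all(processors, input_ids, scores, cur_len, **kwargs):
    for processor in processors:
        function_args = inspect.signature(processor.__call__).parameters
        if len(function_args) > 3:  # more than input_ids/scores/cur_len
            extra = list(function_args.keys())[3:]
            if not all(arg in kwargs for arg in extra):
                raise ValueError(f"Missing kwargs for {processor.__class__}: {extra}")
            scores = processor(input_ids, scores, cur_len, **kwargs)
        else:
            scores = processor(input_ids, scores, cur_len)
    return scores

print(run_all([Plain(), NeedsExtra()], None, 1.0, 3, boost=2.0))  # 3.0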
@@ -107,7 +110,7 @@ class TFTemperatureLogitsWarper(TFLogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
scores = scores / self.temperature
return scores
@@ -133,7 +136,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
# Boolean mask containing all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
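A quick eager illustration of the mask on the previous line, with invented scores: `tf.math.top_k(...)[0][..., -1:]` is the k-th best score per row, and everything strictly below it is flagged for removal.

import tensorflow as tf

scores = tf.constant([[0.1, 0.4, 0.2, 0.3]])
kth_best = tf.math.top_k(scores, k=2)[0][..., -1:]  # 0.3, shape (1, 1)
indices_to_remove = scores < kth_best
print(indices_to_remove.numpy())  # [[ True False  True False]]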
@@ -163,7 +166,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
mask_scores = tf.fill(scores.shape, self.filter_value)
@@ -305,58 +308,75 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = bad_words_ids
# stores the information about bad words in three tensors:
# 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
# 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
if any([word_len == 0 for word_len in bad_word_seqs_len]):
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
# 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
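For reference, a small eager example of tensor 1 above, with a toy `bad_words_ids`: ragged sequences are padded with `-1` into a rectangle so they can be compared with plain tensor ops.

import tensorflow as tf

bad_words_ids = [[1, 2, 3], [5], [7, 8]]  # hypothetical forbidden sequences
padded = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
print(padded.numpy())
# [[ 1  2  3]
#  [ 5 -1 -1]
#  [ 7  8 -1]]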
def calc_banned_bad_words_ids(self, prev_input_ids):
banned_tokens = []
def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
def _tokens_match(bad_word_seq_number):
def _len_one():
# If the bad sequence only has one token, always mask it
return tf.cond(
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
lambda: tf.ones((), dtype=tf.bool),
_len_greater_than_cur_len,
)
def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev tokens they can't be equal
return False
def _len_greater_than_cur_len():
# Otherwise, if the bad sequence is longer than the current length, it can never match
return tf.cond(
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
lambda: tf.zeros((), dtype=tf.bool),
_match_found,
)
if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False
def _match_found():
# Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
# an answer (otherwise we get indexing exceptions)
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
return tf.cond(
tf.math.reduce_all(
tf.math.equal(
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
)
),
lambda: tf.ones((), dtype=tf.bool),
lambda: tf.zeros((), dtype=tf.bool),
)
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
match = _len_one()
return match
for banned_token_seq in self.bad_words_ids:
assert (
len(banned_token_seq) > 0
), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
return row_banned_tokens
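In eager terms, the core comparison inside `_match_found` reduces to the following sketch (invented values): a bad word sequence bans its final token when all of its earlier tokens match the tail of the generated row.

import tensorflow as tf

row_input_ids = tf.constant([3, 5, 9])  # tokens generated so far
bad_seq = tf.constant([5, 9, 4])        # forbidden sequence, unpadded length 3

compare_len = 3 - 1  # all but the last token must already have been generated
prefix_matches = tf.math.reduce_all(
    tf.math.equal(row_input_ids[-compare_len:], bad_seq[:compare_len])
)
print(prefix_matches.numpy())  # True -> bad_seq[-1] (token 4) would be banned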
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
vocab_size = scores.shape[-1]
# calculate a list of banned tokens according to bad words
banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
banned_tokens_indices_mask.append(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
# We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
# `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
# To remain simple and XLA-compatible, we work in a per-row fashion.
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
# a frequent choke point. (make `cur_len` a tensor?)
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
row_input_ids, row_score = row_inputs
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
banned_tokens_mask = tf.scatter_nd(
indices=tf.expand_dims(banned_tokens, axis=-1),
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
shape=row_score.shape,
)
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
return row_score
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
return scores
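A minimal sketch of the per-row masking step above, with made-up values: scatter the banned token ids into a boolean mask over the vocabulary, then overwrite the corresponding scores with `-inf`.

import tensorflow as tf

vocab_size = 6
banned_tokens = tf.constant([1, 4])
row_score = tf.zeros((vocab_size,))

banned_mask = tf.scatter_nd(
    indices=tf.expand_dims(banned_tokens, axis=-1),  # shape (num_banned, 1)
    updates=tf.ones_like(banned_tokens, dtype=tf.bool),
    shape=row_score.shape,
)
row_score = tf.where(banned_mask, -float("inf"), row_score)
print(row_score.numpy())  # [ 0. -inf  0.  0. -inf  0.]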
@@ -401,6 +421,11 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
# https://github.com/huggingface/transformers/pull/16974
if not tf.executing_eagerly():
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
batch_size, vocab_size = scores.shape
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
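The guard added above relies on `tf.executing_eagerly()` evaluating to `False` while a `tf.function` is being traced; a minimal sketch:

import tensorflow as tf

@tf.function
def traced_fn():
    # Evaluated at trace time: False inside a tf.function, which is what makes
    # the processor above raise NotImplementedError under graph/XLA execution.
    return tf.constant(tf.executing_eagerly())

assert tf.executing_eagerly()         # True in eager mode at the top level
assert not bool(traced_fn().numpy())  # False once traced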

View File

@@ -2030,7 +2030,7 @@ class TFGenerationMixin:
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[: current_pos[0]])
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
# argmax
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2301,8 +2301,8 @@ class TFGenerationMixin:
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[:cur_len])
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
# sample
if seed is not None:
@@ -2726,7 +2726,7 @@ class TFGenerationMixin:
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
)
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
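For context, `flatten_beam_dim` and `unflatten_beam_dim` are helpers defined elsewhere in this file; roughly (this is an approximation, not the exact library code), they merge and restore the beam axis, since logits processors expect `(batch_size * num_beams, vocab_size)` inputs:

import tensorflow as tf

def flatten_beam_dim(tensor):
    shape = tensor.shape.as_list()
    return tf.reshape(tensor, [shape[0] * shape[1]] + shape[2:])

def unflatten_beam_dim(tensor, batch_size, num_beams):
    return tf.reshape(tensor, [batch_size, num_beams] + tensor.shape.as_list()[1:])

log_probs = tf.random.uniform((2, 3, 10))  # (batch, beams, vocab)
flat = flatten_beam_dim(log_probs)         # (6, 10): what processors see
restored = unflatten_beam_dim(flat, 2, 3)  # back to (2, 3, 10)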

View File

@@ -75,6 +75,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
@parameterized.expand([(False,), (True,)])
def test_temperature_dist_warper(self, use_xla):
input_ids = None
cur_len = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
@@ -94,8 +95,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
# uniform distribution stays uniform
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
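The XLA variants of these tests all follow the same pattern, sketched below with a hypothetical stand-in warper: compile the processor's `__call__` with `tf.function(..., jit_compile=True)` and call it with the same arguments as the eager path.

import tensorflow as tf

def fake_temperature_warper(input_ids, scores, cur_len):
    return scores / 0.5  # stand-in for a real TFLogitsWarper.__call__

xla_warper = tf.function(fake_temperature_warper, jit_compile=True)
scores = tf.ones((2, 20))
warped = xla_warper(None, scores, None)  # input_ids/cur_len unused by this warper
print(warped[0, 0].numpy())  # 2.0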
@@ -142,6 +143,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
@parameterized.expand([(False,), (True,)])
def test_top_k_dist_warper(self, use_xla):
input_ids = None
cur_len = None
vocab_size = 10
batch_size = 2
@@ -153,7 +155,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
if use_xla:
top_k_warp = tf.function(top_k_warp, jit_compile=True)
scores = top_k_warp(input_ids, ramp_logits)
scores = top_k_warp(input_ids, ramp_logits, cur_len)
# check that correct tokens are filtered
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
@@ -167,12 +169,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
if use_xla:
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
scores = top_k_warp_safety_check(input_ids, logits)
scores = top_k_warp_safety_check(input_ids, logits, cur_len)
# uniform dist is not changed
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
@@ -180,6 +182,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
@parameterized.expand([(False,), (True,)])
def test_top_p_dist_warper(self, use_xla):
input_ids = None
cur_len = None
vocab_size = 10
batch_size = 2
@@ -189,7 +192,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
top_p_warp = TFTopPLogitsWarper(0.7)
if use_xla:
top_p_warp = tf.function(top_p_warp, jit_compile=True)
filtered_dist = tf.exp(top_p_warp(input_ids, dist))
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
@@ -208,7 +211,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
if use_xla:
top_p_warp = tf.function(top_p_warp, jit_compile=True)
filtered_dist = top_p_warp(input_ids, ramp_logits)
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
# 2.
@@ -242,7 +245,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
)
def test_no_bad_words_dist_processor(self):
@parameterized.expand([(False,), (True,)])
def test_no_bad_words_dist_processor(self, use_xla):
vocab_size = 5
batch_size = 2
eos_token_id = 4
@@ -255,6 +259,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
if use_xla:
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
@@ -322,7 +328,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
def test_processor_list(self):
@parameterized.expand([(False,), (True,)])
def test_processor_list(self, use_xla):
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it becomes compatible with XLA
batch_size = 4
cur_len = 10
vocab_size = 15
@@ -341,16 +349,24 @@ class TFLogitsProcessorTest(unittest.TestCase):
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TFTopKLogitsWarper(3)
top_p_warp = TFTopPLogitsWarper(0.8)
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
# no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
if use_xla:
min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
top_k_warp = tf.function(top_k_warp, jit_compile=True)
top_p_warp = tf.function(top_p_warp, jit_compile=True)
# no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
# no processor list
scores = min_dist_proc(input_ids, scores, cur_len)
scores = temp_dist_warp(input_ids, scores)
scores = temp_dist_warp(input_ids, scores, cur_len)
scores = rep_penalty_proc(input_ids, scores, cur_len)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores, cur_len)
scores = top_k_warp(input_ids, scores, cur_len)
scores = top_p_warp(input_ids, scores, cur_len)
# scores = no_repeat_proc(input_ids, scores, cur_len)
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
# with processor list
@@ -361,11 +377,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
# no_repeat_proc,
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
scores_comp = processor(input_ids, scores_comp, cur_len)
# remove inf
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)