Adding FlaxNoRepeatNGramLogitsProcessor (#29677)

* fix issue with logit processor in beam search in Flax

* adding FlaxNoRepeatNGramLogitsProcessor class + unit test

* style correction and code verification

* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests

* fix an issue where n-grams were banned only if they appeared exactly once + update the description of get_previous_ngrams

* replace the non-jit-compatible masking of n-grams that are not yet generated with a jittable version

* Revert "fix issue with logit processor in beam search in Flax"

This reverts commit 09b70d7e4d.

* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor

* change the method of casting to boolean of banned tokens indices

* fix code style

* remove some unnecessary operations + compute the update indices significantly faster using jax.lax.fori_loop

* remove useless loop iterations

* introduce local variables for values that were computed and used multiple times

* fix format

théo gigant, 2024-04-02 11:39:33 +02:00, committed by GitHub
parent 33288ff150
commit fed27ffc7e
4 changed files with 135 additions and 2 deletions
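
For orientation, this is roughly what the new processor computes, written as a dense, eager reference rather than the sparse, jit-compatible version the commit adds. The helper name is made up and the Python loops are deliberately not jittable; it only illustrates the banning rule.

import jax.numpy as jnp

def ban_repeated_ngrams_dense(input_ids, scores, ngram_size, cur_len):
    # enumerate every n-gram generated so far, per batch row, then ban any token that
    # would complete an already-seen n-gram after the last ngram_size - 1 tokens
    banned = jnp.zeros_like(scores, dtype=bool)
    for b in range(input_ids.shape[0]):
        seen = set()
        for start in range(cur_len - ngram_size + 1):
            seen.add(tuple(int(t) for t in input_ids[b, start : start + ngram_size]))
        prefix = tuple(int(t) for t in input_ids[b, cur_len - ngram_size + 1 : cur_len])
        for token in range(scores.shape[1]):
            if prefix + (token,) in seen:
                banned = banned.at[b, token].set(True)
    return jnp.where(banned, -jnp.inf, scores)

out = ban_repeated_ngrams_dense(jnp.array([[1, 1, 2, 1]]), jnp.zeros((1, 3)), ngram_size=2, cur_len=4)
# -> [[0., -inf, -inf]]: after the trailing 1, tokens 1 and 2 would repeat the 2-grams (1, 1) and (1, 2)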

@@ -162,6 +162,7 @@ else:
        "FlaxTopKLogitsWarper",
        "FlaxTopPLogitsWarper",
        "FlaxWhisperTimeStampLogitsProcessor",
        "FlaxNoRepeatNGramLogitsProcessor",
    ]
    _import_structure["flax_utils"] = [
        "FlaxGenerationMixin",

@@ -294,6 +295,7 @@ if TYPE_CHECKING:
        FlaxLogitsProcessorList,
        FlaxLogitsWarper,
        FlaxMinLengthLogitsProcessor,
        FlaxNoRepeatNGramLogitsProcessor,
        FlaxSuppressTokensAtBeginLogitsProcessor,
        FlaxSuppressTokensLogitsProcessor,
        FlaxTemperatureLogitsWarper,

@@ -18,6 +18,7 @@ import inspect
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.experimental import sparse

from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -455,3 +456,89 @@ class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
        scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)
        return scores

class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
        """
        Gets a sparse matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that represents the n-grams that
        occurred previously. The BCOO representation allows storing only the few non-zero entries, instead of the
        full (huge) matrix.
        """
        batch_size, seq_len = input_ids.shape
        # number of n-grams in the whole sequence
        seq_ngrams = seq_len - (self.ngram_size - 1)
        # number of n-grams in the currently generated sequence
        cur_ngrams = cur_len - (self.ngram_size - 1)

        def body_fun(i, val):
            # i enumerates n-grams in batch-major order: recover the batch row and the start position
            b = i % batch_size
            pos = i // batch_size
            return val.at[i].set(
                jnp.array(
                    [
                        b,
                    ]
                    + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
                )
            )

        shape = (batch_size * seq_ngrams, self.ngram_size + 1)
        all_update_indices = jax.lax.fori_loop(
            0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
        )

        # ignore the n-grams not yet generated
        data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")

        return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)

    def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
        """
        Determines which tokens must be banned given the latest tokens and the previously seen ngrams.
        """

        @sparse.sparsify
        @jax.vmap
        def inner_fn(latest_tokens, previous_ngrams):
            # index the first n-1 dimensions with the latest tokens; the remaining dimension holds,
            # for every vocabulary token, whether completing the n-gram with it was already seen
            return previous_ngrams[tuple(latest_tokens)]

        return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def true_fn():
            _, vocab_size = scores.shape

            # store the previously seen n-grams
            previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)

            # get the n-1 last tokens that prefix the n-gram being generated
            latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
            latest_tokens = jax.lax.dynamic_update_slice(
                latest_tokens,
                jax.lax.dynamic_slice(
                    input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
                ),
                (0, 0),
            )

            # compute the banned tokens, i.e. all tokens that, appended to the latest tokens, form an n-gram
            # that was previously generated
            banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
            return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)

        output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
        return output
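
To make the two helpers above concrete, here is a small hypothetical probe (illustrative only, not part of the commit) that builds the sparse n-gram table for one toy sequence and reads the banned continuations out of it. The vocabulary size and dummy scores are made up; it assumes the class defined above is in scope.

import jax.numpy as jnp

proc = FlaxNoRepeatNGramLogitsProcessor(2)
input_ids = jnp.array([[1, 1, 2, 1]], dtype=jnp.int32)
scores = jnp.zeros((1, 3))  # batch_size=1, vocab_size=3, dummy logits

# sparse (batch, vocab, vocab) tensor whose non-zero entries mark the 2-grams seen so far
previous_ngrams = proc.get_previous_ngrams(input_ids, vocab_size=3, cur_len=4)
# the single token that prefixes the next 2-gram
latest_tokens = input_ids[:, -1:]
mask = proc.get_banned_tokens_mask(latest_tokens, previous_ngrams)
# mask is [[0., 1., 1.]] because the 2-grams (1, 1) and (1, 2) were already generated
new_scores = jnp.where(mask.astype(bool), -jnp.inf, scores)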

@@ -40,6 +40,7 @@ from .flax_logits_process import (
    FlaxForceTokensLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxNoRepeatNGramLogitsProcessor,
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
    FlaxTemperatureLogitsWarper,

@@ -534,6 +535,8 @@ class FlaxGenerationMixin:
                [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
            ]
            processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
            processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        return processors
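
With this hook in place, the option should flow from the generation config like any other Flax generation setting. A minimal end-to-end sketch follows; the t5-small checkpoint is only a placeholder for any Flax model with generation support.

from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="np")
# ask generate() to never repeat a 2-gram; the setting reaches FlaxNoRepeatNGramLogitsProcessor
# through the generation config and the branch added above
outputs = model.generate(**inputs, no_repeat_ngram_size=2, max_length=32)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))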

@@ -33,6 +33,7 @@ if is_flax_available():
        FlaxForcedEOSTokenLogitsProcessor,
        FlaxLogitsProcessorList,
        FlaxMinLengthLogitsProcessor,
        FlaxNoRepeatNGramLogitsProcessor,
        FlaxTemperatureLogitsWarper,
        FlaxTopKLogitsWarper,
        FlaxTopPLogitsWarper,
@@ -197,6 +198,26 @@ class LogitsProcessorTest(unittest.TestCase):
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertFalse(jnp.isinf(scores).any())

    def test_no_repeat_ngram_dist_processor(self):
        vocab_size = 3
        batch_size = 2

        cur_len = 4
        input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
        scores = self._get_uniform_logits(batch_size, vocab_size)

        no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
        no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)

        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)

        # 2-gram would forbid 2nd and 3rd token (1, 2) at 1st batch and 1st token (0) at 2nd batch
        self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

        # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
        self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])

    def test_processor_list(self):
        batch_size = 4
        sequence_length = 10
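
As a worked check of the expected masks in test_no_repeat_ngram_dist_processor above (reasoning only, derived from the test inputs): for the first row [1, 1, 2, 1] the generated 2-grams are (1, 1), (1, 2) and (2, 1); the last token is 1, so continuations 1 and 2 would repeat a 2-gram while 0 would not, giving [False, True, True]. For the second row [0, 1, 0, 1] the 2-grams are (0, 1) and (1, 0); the last token is 1, so only token 0 is banned. With 3-grams the prefixes are (2, 1) and (0, 1): no seen 3-gram starts with (2, 1), while (0, 1, 0) was already generated, matching [[False, False, False], [True, False, False]].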
@@ -216,6 +237,7 @@ class LogitsProcessorTest(unittest.TestCase):
        temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
        top_k_warp = FlaxTopKLogitsWarper(3)
        top_p_warp = FlaxTopPLogitsWarper(0.8)
        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

        # instantiate all logits processors
        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

@@ -231,10 +253,19 @@ class LogitsProcessorTest(unittest.TestCase):
        scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)

        # with processor list
        processor = FlaxLogitsProcessorList(
            [
                temp_dist_warp,
                top_k_warp,
                top_p_warp,
                min_dist_proc,
                bos_dist_proc,
                eos_dist_proc,
                no_repeat_proc,
            ]
        )
        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
@@ -263,6 +294,7 @@ class LogitsProcessorTest(unittest.TestCase):
        temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
        top_k_warp = FlaxTopKLogitsWarper(3)
        top_p_warp = FlaxTopPLogitsWarper(0.8)
        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

        # instantiate all logits processors
        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

@@ -279,12 +311,21 @@ class LogitsProcessorTest(unittest.TestCase):
            scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
            return scores

        # with processor list
        def run_processor_list(input_ids, scores, cur_len):
            processor = FlaxLogitsProcessorList(
                [
                    temp_dist_warp,
                    top_k_warp,
                    top_p_warp,
                    min_dist_proc,
                    bos_dist_proc,
                    eos_dist_proc,
                    no_repeat_proc,
                ]
            )
            scores = processor(input_ids, scores, cur_len=cur_len)
            return scores
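
Beyond the jitted processor-list test above, a small hypothetical check (not part of the diff) that the processor itself traces under jax.jit; cur_len then becomes a traced value, which the lax.cond, dynamic_slice and fori_loop code paths are written to support:

import jax
import jax.numpy as jnp

proc = FlaxNoRepeatNGramLogitsProcessor(2)

@jax.jit
def apply_no_repeat(input_ids, scores, cur_len):
    # ngram_size stays a static Python attribute; only the arrays and cur_len are traced
    return proc(input_ids, scores, cur_len)

input_ids = jnp.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=jnp.int32)
scores = jnp.zeros((2, 3))
filtered = apply_no_repeat(input_ids, scores, 4)  # banned entries come back as -inf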