Adding FlaxNoRepeatNGramLogitsProcessor (#29677)
* fix issue with logit processor in beam search in Flax
* adding FlaxNoRepeatNGramLogitsProcessor class + unit test
* style correction and code verification
* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests
* fix an issue where n-grams were banned only if they appeared exactly once + update the description of get_previous_ngrams
* replace non-jit compatible masking of ngrams that are not yet generated with jittable version
* Revert "fix issue with logit processor in beam search in Flax"
  This reverts commit 09b70d7e4d.
* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor
* change the method of casting to boolean of banned tokens indices
* fix code style
* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop
* remove useless loop iterations
* factor out variables that were calculated and used multiple times
* fix format
This commit is contained in:
parent 33288ff150
commit fed27ffc7e
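
Usage note (editor's sketch, not part of the commit): with this change, the Flax generate() path should pick up `no_repeat_ngram_size` from the generation config, since `_get_logits_processor` (see the diff below) appends the new processor whenever `generation_config.no_repeat_ngram_size > 0`. The checkpoint and exact keyword arguments here are illustrative, not taken from the diff.

    from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint with Flax weights
    model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The quick brown fox", return_tensors="np")
    # no_repeat_ngram_size=2: any bigram may appear at most once in the generated sequence
    out = model.generate(inputs["input_ids"], max_new_tokens=20, no_repeat_ngram_size=2)
    print(tokenizer.batch_decode(out.sequences, skip_special_tokens=True))
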
@@ -162,6 +162,7 @@ else:
         "FlaxTopKLogitsWarper",
         "FlaxTopPLogitsWarper",
         "FlaxWhisperTimeStampLogitsProcessor",
+        "FlaxNoRepeatNGramLogitsProcessor",
     ]
     _import_structure["flax_utils"] = [
         "FlaxGenerationMixin",

@@ -294,6 +295,7 @@ if TYPE_CHECKING:
             FlaxLogitsProcessorList,
             FlaxLogitsWarper,
             FlaxMinLengthLogitsProcessor,
+            FlaxNoRepeatNGramLogitsProcessor,
             FlaxSuppressTokensAtBeginLogitsProcessor,
             FlaxSuppressTokensLogitsProcessor,
             FlaxTemperatureLogitsWarper,

@@ -18,6 +18,7 @@ import inspect
 import jax
 import jax.lax as lax
 import jax.numpy as jnp
+from jax.experimental import sparse

 from ..utils import add_start_docstrings
 from ..utils.logging import get_logger

@@ -455,3 +456,89 @@ class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
         scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

         return scores
+
+
+class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
+    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
+
+    Args:
+        ngram_size (`int`):
+            All ngrams of size `ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size: int):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
+        """
+        get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that
+        represent the n-grams that occured previously.
+        The BCOO representation allow to store only the few non-zero entries, instead of the full (huge) matrix
+        """
+        batch_size, seq_len = input_ids.shape
+        # number of n-grams in the whole sequence
+        seq_ngrams = seq_len - (self.ngram_size - 1)
+        # number of n-grams in the currently generated sequence
+        cur_ngrams = cur_len - (self.ngram_size - 1)
+
+        def body_fun(i, val):
+            b = i % batch_size
+            pos = i // batch_size
+            return val.at[i].set(
+                jnp.array(
+                    [
+                        b,
+                    ]
+                    + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
+                )
+            )
+
+        shape = (batch_size * seq_ngrams, self.ngram_size + 1)
+        all_update_indices = jax.lax.fori_loop(
+            0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
+        )
+
+        # ignore the n-grams not yet generated
+        data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")
+
+        return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
+
+    def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
+        """
+        Determines which tokens must be banned given latest tokens and the previously seen
+        ngrams.
+        """
+
+        @sparse.sparsify
+        @jax.vmap
+        def inner_fn(latest_tokens, previous_ngrams):
+            return previous_ngrams[tuple(latest_tokens)]
+
+        return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))
+
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
+        def true_fn():
+            _, vocab_size = scores.shape
+            # store the previously seen n-grams
+            previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)
+
+            # get the n-1 last tokens that prefix the n-gram being generated
+            latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
+            latest_tokens = jax.lax.dynamic_update_slice(
+                latest_tokens,
+                jax.lax.dynamic_slice(
+                    input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
+                ),
+                (0, 0),
+            )
+
+            # compute the banned tokens, ie all the tokens that when added to the latest tokens lead to a n-gram that was previously generated
+            banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
+            return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)
+
+        output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
+        return output
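
Editor's sketch (not part of the commit) of the sparse bookkeeping above: `get_previous_ngrams` stores every previously seen n-gram as a non-zero entry of a BCOO array of shape (batch_size,) + (vocab_size,) * ngram_size, and `get_banned_tokens_mask` reads back the row selected by the last n-1 tokens. The snippet below rebuilds that table by hand for the bigrams of the toy batch used in the new unit test ([[1, 1, 2, 1], [0, 1, 0, 1]]); variable names are illustrative.

    import jax
    import jax.numpy as jnp
    from jax.experimental import sparse

    batch_size, vocab_size, ngram_size = 2, 3, 2

    # one row per previously seen bigram: [batch index, first token, second token]
    indices = jnp.array([[0, 1, 1], [0, 1, 2], [0, 2, 1], [1, 0, 1], [1, 1, 0]])
    data = jnp.ones(indices.shape[0], dtype="float32")
    previous_ngrams = sparse.BCOO((data, indices), shape=(batch_size,) + (vocab_size,) * ngram_size)

    # last (n-1) tokens of each sequence: both toy sequences currently end in token 1
    latest_tokens = jnp.array([[1], [1]])

    @sparse.sparsify
    @jax.vmap
    def banned_row(latest, table):
        # index the per-batch table with the prefix; the remaining axis enumerates candidate next tokens
        return table[tuple(latest)]

    mask = sparse.bcoo_todense(banned_row(latest_tokens, previous_ngrams)).astype(bool)
    print(mask.tolist())  # [[False, True, True], [True, False, False]] -- matches the 2-gram case in the test
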
@@ -40,6 +40,7 @@ from .flax_logits_process import (
     FlaxForceTokensLogitsProcessor,
     FlaxLogitsProcessorList,
     FlaxMinLengthLogitsProcessor,
+    FlaxNoRepeatNGramLogitsProcessor,
     FlaxSuppressTokensAtBeginLogitsProcessor,
     FlaxSuppressTokensLogitsProcessor,
     FlaxTemperatureLogitsWarper,

@@ -534,6 +535,8 @@ class FlaxGenerationMixin:
                 [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
             ]
             processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
+        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
+            processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
         processors = self._merge_criteria_processor_list(processors, logits_processor)

         return processors
@@ -33,6 +33,7 @@ if is_flax_available():
         FlaxForcedEOSTokenLogitsProcessor,
         FlaxLogitsProcessorList,
         FlaxMinLengthLogitsProcessor,
+        FlaxNoRepeatNGramLogitsProcessor,
         FlaxTemperatureLogitsWarper,
         FlaxTopKLogitsWarper,
         FlaxTopPLogitsWarper,

@@ -197,6 +198,26 @@ class LogitsProcessorTest(unittest.TestCase):
         scores = logits_processor(input_ids, scores, cur_len=cur_len)
         self.assertFalse(jnp.isinf(scores).any())

+    def test_no_repeat_ngram_dist_processor(self):
+        vocab_size = 3
+        batch_size = 2
+
+        cur_len = 4
+        input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+
+        no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
+        no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)
+
+        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
+        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)
+
+        # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
+        self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
+
+        # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
+        self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
+
     def test_processor_list(self):
         batch_size = 4
         sequence_length = 10

@@ -216,6 +237,7 @@ class LogitsProcessorTest(unittest.TestCase):
         temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
         top_k_warp = FlaxTopKLogitsWarper(3)
         top_p_warp = FlaxTopPLogitsWarper(0.8)
+        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

         # instantiate all logits processors
         min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

@@ -231,10 +253,19 @@ class LogitsProcessorTest(unittest.TestCase):
         scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
         scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
         scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
+        scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)

         # with processor list
         processor = FlaxLogitsProcessorList(
-            [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
+            [
+                temp_dist_warp,
+                top_k_warp,
+                top_p_warp,
+                min_dist_proc,
+                bos_dist_proc,
+                eos_dist_proc,
+                no_repeat_proc,
+            ]
         )
         scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

@@ -263,6 +294,7 @@ class LogitsProcessorTest(unittest.TestCase):
         temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
         top_k_warp = FlaxTopKLogitsWarper(3)
         top_p_warp = FlaxTopPLogitsWarper(0.8)
+        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

         # instantiate all logits processors
         min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

@@ -279,12 +311,21 @@ class LogitsProcessorTest(unittest.TestCase):
             scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
             scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
             scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
+            scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
             return scores

         # with processor list
         def run_processor_list(input_ids, scores, cur_len):
             processor = FlaxLogitsProcessorList(
-                [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
+                [
+                    temp_dist_warp,
+                    top_k_warp,
+                    top_p_warp,
+                    min_dist_proc,
+                    bos_dist_proc,
+                    eos_dist_proc,
+                    no_repeat_proc,
+                ]
             )
             scores = processor(input_ids, scores, cur_len=cur_len)
             return scores
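
Editor's note (not part of the commit): the processor is built from lax primitives (fori_loop, cond, dynamic_slice) plus the BCOO sparse format, which is what lets it run inside the jitted processor-list test above. A hedged standalone sketch of the same behavior under jax.jit, reusing the toy inputs from test_no_repeat_ngram_dist_processor:

    import jax
    import jax.numpy as jnp
    from transformers.generation import FlaxNoRepeatNGramLogitsProcessor

    proc = FlaxNoRepeatNGramLogitsProcessor(2)

    @jax.jit
    def apply(input_ids, scores, cur_len):
        # cur_len is traced, so one compiled function serves every decoding step
        return proc(input_ids, scores, cur_len)

    input_ids = jnp.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=jnp.int32)
    scores = jnp.zeros((2, 3))  # uniform dummy logits over a 3-token vocabulary
    print(jnp.isinf(apply(input_ids, scores, 4)).tolist())
    # expected, per the new test: [[False, True, True], [True, False, False]]
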