mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912)
* XLA min len, forced eos, and forced bos Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
f6210c49e2
commit
809dac48f9
@ -215,13 +215,18 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
|
||||
self.min_length = min_length
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
|
||||
# generate is not XLA - compileable anyways
|
||||
if cur_len < self.min_length:
|
||||
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
|
||||
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
||||
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
|
||||
eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
|
||||
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
||||
return scores
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
# applies eos token masking if the first argument is true
|
||||
scores = tf.cond(
|
||||
tf.less(cur_len, self.min_length),
|
||||
lambda: self._apply_eos_token_mask(scores),
|
||||
lambda: tf.identity(scores),
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf
|
||||
|
||||
@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
||||
return scores
|
||||
|
||||
def test_min_length_dist_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_min_length_dist_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
cur_len = 5
|
||||
@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
[[True, True, False, True, True], [True, True, True, False, True]],
|
||||
)
|
||||
|
||||
def test_forced_bos_token_logits_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_bos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
bos_token_id = 0
|
||||
|
||||
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the bos_token_id score
|
||||
cur_len = 1
|
||||
@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
def test_forced_eos_token_logits_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_eos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||
cur_len = 4
|
||||
|
Loading…
Reference in New Issue
Block a user