Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[Flax] Add Beam Search (#12131)
* fix_torch_device_generate_test

* remove @

* push new logit processors

* add processors

* save first working version

* save intermediate

* finish

* make style

* make fix-copies

* finish

* Update tests/test_modeling_flax_bart.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent 802ffaff0d
commit c3c39f7e84
@@ -186,6 +186,15 @@ generation.
 .. autoclass:: transformers.FlaxTopKLogitsWarper
     :members: __call__
 
+.. autoclass:: transformers.FlaxForcedBOSTokenLogitsProcessor
+    :members: __call__
+
+.. autoclass:: transformers.FlaxForcedEOSTokenLogitsProcessor
+    :members: __call__
+
+.. autoclass:: transformers.FlaxMinLengthLogitsProcessor
+    :members: __call__
+
 StoppingCriteria
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1486,9 +1486,12 @@ else:
     # FLAX-backed objects
     if is_flax_available():
         _import_structure["generation_flax_logits_process"] = [
+            "FlaxForcedBOSTokenLogitsProcessor",
+            "FlaxForcedEOSTokenLogitsProcessor",
             "FlaxLogitsProcessor",
             "FlaxLogitsProcessorList",
             "FlaxLogitsWarper",
+            "FlaxMinLengthLogitsProcessor",
             "FlaxTemperatureLogitsWarper",
             "FlaxTopKLogitsWarper",
             "FlaxTopPLogitsWarper",
@@ -2814,9 +2817,12 @@ if TYPE_CHECKING:
 
     if is_flax_available():
         from .generation_flax_logits_process import (
+            FlaxForcedBOSTokenLogitsProcessor,
+            FlaxForcedEOSTokenLogitsProcessor,
            FlaxLogitsProcessor,
            FlaxLogitsProcessorList,
            FlaxLogitsWarper,
+            FlaxMinLengthLogitsProcessor,
            FlaxTemperatureLogitsWarper,
            FlaxTopKLogitsWarper,
            FlaxTopPLogitsWarper,
@@ -81,16 +81,18 @@ class FlaxLogitsProcessorList(list):
     """
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
+    ) -> jax_xla.DeviceArray:
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
-            if len(function_args) > 2:
+            if len(function_args) > 3:
                 assert all(
                     arg in kwargs for arg in list(function_args.keys())[2:]
                 ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores, cur_len, **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores, cur_len)
         return scores
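A standalone sketch (not part of the commit) of why `cur_len` is now threaded through every processor: inside a jit-compiled generation loop the sequence buffer always has a fixed `(batch, max_length)` shape, so the current position cannot be read off the array and must be passed explicitly. Assumes a transformers release that ships these Flax classes and a JAX version contemporary to this commit (the processors still use `jax.ops.index_update`, which later JAX releases removed).

```python
import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxTemperatureLogitsWarper,
)

processors = FlaxLogitsProcessorList(
    [FlaxTemperatureLogitsWarper(temperature=0.7), FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=2)]
)

input_ids = jnp.zeros((2, 8), dtype="i4")  # fixed-size (batch, max_length) buffer
scores = jnp.zeros((2, 10))                # logits over a 10-token vocabulary

# at position 3 the min-length processor still masks EOS (token id 2)
out = processors(input_ids, scores, cur_len=3)
print(out[:, 2])  # [-inf, -inf]
```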
@@ -109,7 +111,9 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
 
         self.temperature = temperature
 
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
         scores = scores / self.temperature
         return scores
@@ -137,7 +141,9 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
         topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
 
         mask_scores = jnp.full_like(scores, self.filter_value)
@@ -177,7 +183,9 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
         batch_size, vocab_size = scores.shape
         next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
@@ -190,3 +198,94 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
         next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
         next_scores = next_scores_flat.reshape(batch_size, vocab_size)
         return next_scores
+
+
+class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token.
+
+    Args:
+        bos_token_id (:obj:`int`):
+            The id of the token to force as the first generated token.
+    """
+
+    def __init__(self, bos_token_id: int):
+        self.bos_token_id = bos_token_id
+
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
+        new_scores = jnp.full(scores.shape, -float("inf"))
+
+        apply_penalty = 1 - jnp.bool_(cur_len - 1)
+
+        scores = jnp.where(
+            apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
+        )
+
+        return scores
+
+
+class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when
+    :obj:`max_length` is reached.
+
+    Args:
+        max_length (:obj:`int`):
+            The maximum length of the sequence to be generated.
+        eos_token_id (:obj:`int`):
+            The id of the token to force as the last generated token when :obj:`max_length` is reached.
+    """
+
+    def __init__(self, max_length: int, eos_token_id: int):
+        self.max_length = max_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
+        new_scores = jnp.full(scores.shape, -float("inf"))
+
+        apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
+
+        scores = jnp.where(
+            apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
+        )
+
+        return scores
+
+
+class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    :class:`transformers.FlaxLogitsProcessor` enforcing a min-length by setting EOS probability to 0.
+
+    Args:
+        min_length (:obj:`int`):
+            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
+        eos_token_id (:obj:`int`):
+            The id of the `end-of-sequence` token.
+    """
+
+    def __init__(self, min_length: int, eos_token_id: int):
+        if not isinstance(min_length, int) or min_length < 0:
+            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
+
+        if not isinstance(eos_token_id, int) or eos_token_id < 0:
+            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+
+        self.min_length = min_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(
+        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
+    ) -> jax_xla.DeviceArray:
+
+        # create boolean flag to decide if min length penalty should be applied
+        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
+
+        scores = jnp.where(
+            apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
+        )
+
+        return scores
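A short self-contained sketch of the branchless `apply_penalty` pattern used by all three processors above. Under `jit`, `cur_len` is a traced value, so a Python `if` is not allowed; a 0/1 flag plus `jnp.where` selects between the forced distribution and the untouched scores. This re-expression uses the modern `Array.at[...]` API in place of the `jax.ops.index_update` call in the diff, which newer JAX releases removed.

```python
import jax.numpy as jnp

def forced_bos_scores(scores, cur_len, bos_token_id):
    # forced distribution: -inf everywhere except the BOS token
    forced = jnp.full(scores.shape, -jnp.inf).at[:, bos_token_id].set(0.0)
    apply_penalty = 1 - jnp.bool_(cur_len - 1)  # 1 only when cur_len == 1
    return jnp.where(apply_penalty, forced, scores)

scores = jnp.zeros((2, 5))
print(forced_bos_scores(scores, 1, 3)[0])  # only token 3 keeps a finite score
print(forced_bos_scores(scores, 2, 3)[0])  # scores pass through unchanged
```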
@@ -17,6 +17,8 @@
 
 from typing import Dict, Optional
 
+import numpy as np
+
 import flax
 import jax
 import jax.numpy as jnp
@@ -25,7 +27,10 @@ from jax import lax
 
 from .file_utils import ModelOutput
 from .generation_flax_logits_process import (
+    FlaxForcedBOSTokenLogitsProcessor,
+    FlaxForcedEOSTokenLogitsProcessor,
     FlaxLogitsProcessorList,
+    FlaxMinLengthLogitsProcessor,
     FlaxTemperatureLogitsWarper,
     FlaxTopKLogitsWarper,
     FlaxTopPLogitsWarper,
@@ -43,9 +48,8 @@ class FlaxGreedySearchOutput(ModelOutput):
 
 
     Args:
-        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
-            The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
-            padded to :obj:`max_length`.
+        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+            The generated sequences.
     """
 
     sequences: jax_xla.DeviceArray = None
@@ -58,19 +62,35 @@ class FlaxSampleOutput(ModelOutput):
 
 
     Args:
-        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`):
-            The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
-            padded to :obj:`max_length`.
+        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+            The generated sequences.
     """
 
     sequences: jax_xla.DeviceArray = None
 
 
+@flax.struct.dataclass
+class FlaxBeamSearchOutput(ModelOutput):
+    """
+    Flax Base class for outputs of decoder-only generation models using beam search.
+
+
+    Args:
+        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+            The generated sequences.
+        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
+            The scores (log probabilities) of the generated sequences.
+    """
+
+    sequences: jax_xla.DeviceArray = None
+    scores: jax_xla.DeviceArray = None
+
+
 @flax.struct.dataclass
 class GreedyState:
     cur_len: jax_xla.DeviceArray
     sequences: jax_xla.DeviceArray
-    current_token: jax_xla.DeviceArray
+    running_token: jax_xla.DeviceArray
     is_sent_finished: jax_xla.DeviceArray
     model_kwargs: Dict[str, jax_xla.DeviceArray]
@@ -79,12 +99,23 @@ class GreedyState:
 class SampleState:
     cur_len: jax_xla.DeviceArray
     sequences: jax_xla.DeviceArray
-    current_token: jax_xla.DeviceArray
+    running_token: jax_xla.DeviceArray
     is_sent_finished: jax_xla.DeviceArray
     prng_key: jax_xla.DeviceArray
     model_kwargs: Dict[str, jax_xla.DeviceArray]
 
 
+@flax.struct.dataclass
+class BeamSearchState:
+    cur_len: jax_xla.DeviceArray
+    running_sequences: jax_xla.DeviceArray
+    running_scores: jax_xla.DeviceArray
+    sequences: jax_xla.DeviceArray
+    scores: jax_xla.DeviceArray
+    is_sent_finished: jax_xla.DeviceArray
+    model_kwargs: Dict[str, jax_xla.DeviceArray]
+
+
 class FlaxGenerationMixin:
     """
     A class containing all of the functions supporting generation, to be used as a mixin in
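`flax.struct.dataclass` registers the state object as a JAX pytree, which is what lets the whole `BeamSearchState` be carried through `lax.while_loop` and `jit`. A minimal illustration of that mechanism (illustrative names, not from the diff):

```python
import flax
import jax.numpy as jnp
from jax import lax

@flax.struct.dataclass
class Counter:
    step: jnp.ndarray

final = lax.while_loop(
    lambda s: s.step < 5,               # cond_fn receives the dataclass directly
    lambda s: Counter(step=s.step + 1),  # body_fn returns a new immutable state
    Counter(step=jnp.array(0)),
)
print(final.step)  # 5
```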
@@ -110,6 +141,10 @@ class FlaxGenerationMixin:
             model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
         return model_kwargs
 
+    @staticmethod
+    def _expand_to_num_beams(tensor, num_beams):
+        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+
     def generate(
         self,
         input_ids: jax_xla.DeviceArray,
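What `_expand_to_num_beams` does, on concrete shapes (standalone snippet, not from the diff):

```python
import jax.numpy as jnp

tensor = jnp.arange(6).reshape(2, 3)  # (batch=2, seq_len=3)
num_beams = 4
expanded = jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
print(expanded.shape)  # (2, 4, 3): every batch item is replicated across the beam axis
```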
@@ -123,6 +158,13 @@ class FlaxGenerationMixin:
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
+        num_beams: Optional[int] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        min_length: Optional[int] = None,
+        forced_bos_token_id: Optional[int] = None,
+        forced_eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        early_stopping: Optional[bool] = None,
         trace: bool = True,
         params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
         **model_kwargs,
@@ -159,6 +201,8 @@ class FlaxGenerationMixin:
                 The id of the `beginning-of-sequence` token.
             eos_token_id (:obj:`int`, `optional`):
                 The id of the `end-of-sequence` token.
+            num_beams (:obj:`int`, `optional`, defaults to 1):
+                Number of beams for beam search. 1 means no beam search.
             decoder_start_token_id (:obj:`int`, `optional`):
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
             trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -204,9 +248,27 @@ class FlaxGenerationMixin:
             input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
 
         do_sample = do_sample if do_sample is not None else self.config.do_sample
+        num_beams = num_beams if num_beams is not None else self.config.num_beams
 
-        if do_sample:
+        if not do_sample and num_beams == 1:
+            logits_processor = self._get_logits_processor(
+                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+            )
+            return self._greedy_search(
+                input_ids,
+                max_length,
+                pad_token_id,
+                eos_token_id,
+                logits_processor=logits_processor,
+                trace=trace,
+                params=params,
+                model_kwargs=model_kwargs,
+            )
+        elif do_sample and num_beams == 1:
             logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+            logits_processor = self._get_logits_processor(
+                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+            )
             return self._sample(
                 input_ids,
                 max_length,
@@ -214,20 +276,43 @@ class FlaxGenerationMixin:
                 eos_token_id,
                 prng_key,
                 logits_warper=logits_warper,
+                logits_processor=logits_processor,
                 trace=trace,
                 params=params,
                 model_kwargs=model_kwargs,
             )
+        elif not do_sample and num_beams > 1:
+            # broadcast input_ids & encoder_outputs
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+
+            if "encoder_outputs" in model_kwargs:
+                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                )
+
+            if "attention_mask" in model_kwargs:
+                model_kwargs["attention_mask"] = self._expand_to_num_beams(
+                    model_kwargs["attention_mask"], num_beams=num_beams
+                )
+
+            logits_processor = self._get_logits_processor(
+                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+            )
+
+            return self._beam_search(
+                input_ids,
+                max_length,
+                pad_token_id,
+                eos_token_id,
+                length_penalty=length_penalty,
+                early_stopping=early_stopping,
+                logits_processor=logits_processor,
+                trace=trace,
+                params=params,
+                model_kwargs=model_kwargs,
+            )
         else:
-            return self._greedy_search(
-                input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
-                trace=trace,
-                params=params,
-                model_kwargs=model_kwargs,
-            )
+            raise NotImplementedError("`Beam sampling is currently not implemented.")
 
     def _get_logits_warper(
         self, top_k: int = None, top_p: float = None, temperature: float = None
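A hedged end-to-end sketch of the new dispatch: `num_beams > 1` with `do_sample=False` now routes to `_beam_search`. The model and checkpoint below are illustrative, not taken from the commit; any Flax seq2seq model with a generation-capable config should behave similarly on a transformers release containing this change.

```python
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")

# beam search branch: do_sample defaults to False, num_beams=2 > 1
outputs = model.generate(inputs["input_ids"], num_beams=2, max_length=20, early_stopping=True)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```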
@@ -255,12 +340,51 @@ class FlaxGenerationMixin:
 
         return warpers
 
+    def _get_logits_processor(
+        self,
+        no_repeat_ngram_size: int,
+        min_length: int,
+        max_length: int,
+        eos_token_id: int,
+        forced_bos_token_id: int,
+        forced_eos_token_id: int,
+    ) -> FlaxLogitsProcessorList:
+        """
+        This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
+        :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head.
+        """
+        processors = FlaxLogitsProcessorList()
+
+        # init warp parameters
+        no_repeat_ngram_size = (
+            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+        )
+        min_length = min_length if min_length is not None else self.config.min_length
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        forced_bos_token_id = (
+            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
+        )
+        forced_eos_token_id = (
+            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
+        )
+
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        if min_length is not None and eos_token_id is not None and min_length > -1:
+            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
+        if forced_bos_token_id is not None:
+            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
+        if forced_eos_token_id is not None:
+            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
+        return processors
+
     def _greedy_search(
         self,
         input_ids: None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
         model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
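For `None` arguments, `_get_logits_processor` falls back to the model config; with concrete values, the list it builds is equivalent to constructing the processors by hand. A sketch with made-up token ids (not from the commit):

```python
from transformers.generation_flax_logits_process import (
    FlaxForcedBOSTokenLogitsProcessor,
    FlaxForcedEOSTokenLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
)

# roughly what _get_logits_processor(None, 5, 20, 2, 0, 2) would return
processors = FlaxLogitsProcessorList()
processors.append(FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=2))
processors.append(FlaxForcedBOSTokenLogitsProcessor(bos_token_id=0))
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length=20, eos_token_id=2))
```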
@@ -293,7 +417,7 @@ class FlaxGenerationMixin:
         state = GreedyState(
             cur_len=cur_len,
             sequences=sequences,
-            current_token=input_ids,
+            running_token=input_ids,
             is_sent_finished=is_sent_finished,
             model_kwargs=model_kwargs,
         )
@@ -307,8 +431,13 @@ class FlaxGenerationMixin:
 
         def greedy_search_body_fn(state):
             """state update fn."""
-            model_outputs = model(state.current_token, params=params, **state.model_kwargs)
-            next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
+            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
+            logits = model_outputs.logits[:, -1]
+
+            # apply min_length, ...
+            logits = logits_processor(state.sequences, logits, state.cur_len)
+
+            next_token = jnp.argmax(logits, axis=-1)
 
             next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
             next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
@@ -319,7 +448,7 @@ class FlaxGenerationMixin:
             return GreedyState(
                 cur_len=state.cur_len + 1,
                 sequences=next_sequences,
-                current_token=next_token,
+                running_token=next_token,
                 is_sent_finished=next_is_sent_finished,
                 model_kwargs=next_model_kwargs,
             )
@@ -342,6 +471,7 @@ class FlaxGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         prng_key: Optional[jax_xla.DeviceArray] = None,
+        logits_processor: Optional[FlaxLogitsProcessorList] = None,
         logits_warper: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
@@ -377,7 +507,7 @@ class FlaxGenerationMixin:
         state = SampleState(
             cur_len=cur_len,
             sequences=sequences,
-            current_token=input_ids,
+            running_token=input_ids,
             is_sent_finished=is_sent_finished,
             prng_key=prng_key,
             model_kwargs=model_kwargs,
@@ -393,12 +523,14 @@ class FlaxGenerationMixin:
         def sample_search_body_fn(state):
             """state update fn."""
             prng_key, prng_key_next = jax.random.split(state.prng_key)
-            model_outputs = model(state.current_token, params=params, **state.model_kwargs)
+            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
 
             logits = model_outputs.logits[:, -1]
 
+            # apply min_length, ...
+            logits = logits_processor(state.sequences, logits, state.cur_len)
             # apply top_k, top_p, temperature
-            logits = logits_warper(state.sequences, logits)
+            logits = logits_warper(logits, logits, state.cur_len)
 
             next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
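The key-splitting pattern in `sample_search_body_fn`, shown in isolation: each iteration consumes a fresh subkey and carries the successor key forward in the loop state, so sampling is deterministic per seed yet differs across steps (standalone snippet, not from the commit):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.zeros((2, 10))  # uniform over a 10-token vocabulary

for _ in range(3):
    subkey, key = jax.random.split(key)  # subkey used now, key carried forward
    next_token = jax.random.categorical(subkey, logits, axis=-1)
    print(next_token)  # shape (2,), varies from step to step
```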
@@ -412,7 +544,7 @@ class FlaxGenerationMixin:
             return SampleState(
                 cur_len=state.cur_len + 1,
                 sequences=next_sequences,
-                current_token=next_token,
+                running_token=next_token,
                 is_sent_finished=next_is_sent_finished,
                 model_kwargs=next_model_kwargs,
                 prng_key=prng_key_next,
@@ -428,3 +560,251 @@ class FlaxGenerationMixin:
         state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
 
         return FlaxSampleOutput(sequences=state.sequences)
+
+    def _beam_search(
+        self,
+        input_ids: None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        early_stopping: Optional[bool] = None,
+        logits_processor: Optional[FlaxLogitsProcessorList] = None,
+        trace: bool = True,
+        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+    ):
+        """
+        This beam search function is heavily inspired by Flax's official example:
+        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
+        """
+
+        def flatten_beam_dim(tensor):
+            """Flattens the first two dimensions of a non-scalar array."""
+            # ignore scalars (e.g. cache index)
+            if tensor.ndim == 0:
+                return tensor
+            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
+
+        def unflatten_beam_dim(tensor, batch_size, num_beams):
+            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
+            # ignore scalars (e.g. cache index)
+            if tensor.ndim == 0:
+                return tensor
+            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
+
+        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
+            """
+            Gathers the beam slices indexed by beam_indices into new beam array.
+            """
+            batch_indices = jnp.reshape(
+                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
+            )
+
+            def gather_fn(tensor):
+                # ignore scalars (e.g. cache index)
+                if tensor.ndim == 0:
+                    return tensor
+                else:
+                    return tensor[batch_indices, beam_indices]
+
+            return jax.tree_map(gather_fn, nested)
+
+        # init values
+        max_length = max_length if max_length is not None else self.config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+
+        batch_size, num_beams, cur_len = input_ids.shape
+
+        eos_token_id = jnp.array(eos_token_id)
+        pad_token_id = jnp.array(pad_token_id)
+        cur_len = jnp.array(cur_len)
+
+        # per batch,beam-item holding current token in loop.
+        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
+        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
+        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
+
+        # per batch,beam-item state bit indicating if sentence has finished.
+        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
+
+        # per batch,beam-item score, logprobs
+        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
+        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
+
+        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
+        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
+        model = self.decode if self.config.is_encoder_decoder else self
+
+        # flatten beam dim
+        if "encoder_outputs" in model_kwargs:
+            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
+                model_kwargs["encoder_outputs"]["last_hidden_state"]
+            )
+        if "attention_mask" in model_kwargs:
+            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
+
+        # initialize model specific kwargs
+        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
+
+        # initialize state
+        state = BeamSearchState(
+            cur_len=cur_len,
+            running_sequences=running_sequences,
+            running_scores=running_scores,
+            sequences=sequences,
+            scores=scores,
+            is_sent_finished=is_sent_finished,
+            model_kwargs=model_kwargs,
+        )
+
+        def beam_search_cond_fn(state):
+            """beam search state termination condition fn."""
+
+            # 1. is less than max length?
+            not_max_length_yet = state.cur_len < max_length
+
+            # 2. can the new beams still improve?
+            best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
+            worst_finished_score = jnp.where(
+                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
+            )
+            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
+
+            # 3. is there still a beam that has not finished?
+            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
+
+            return not_max_length_yet & still_open_beam & improvement_still_possible
+
+        def beam_search_body_fn(state):
+            """beam search state update fn."""
+            # 1. Forward current tokens
+            # Collect the current position slice along length to feed the fast
+            # autoregressive decoder model. Flatten the beam dimension into batch
+            # dimension for feeding into the model.
+            # unflatten beam dimension
+            # Unflatten beam dimension in attention cache arrays
+            input_token = flatten_beam_dim(
+                lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
+            )
+            model_outputs = model(input_token, params=params, **state.model_kwargs)
+            logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
+            cache = jax.tree_map(
+                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
+            )
+
+            # 2. Compute log probs
+            # get log probabilities from logits,
+            # process logits with processors (*e.g.* min_length, ...), and
+            # add new logprobs to existing running logprobs scores.
+            log_probs = jax.nn.log_softmax(logits)
+            log_probs = logits_processor(
+                flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
+            )
+            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
+            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
+            vocab_size = log_probs.shape[2]
+            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
+
+            # 3. Retrieve top-K
+            # Each item in batch has num_beams * vocab_size candidate sequences.
+            # For each item, get the top 2*k candidates with the highest log-
+            # probabilities. We gather the top 2*K beams here so that even if the best
+            # K sequences reach EOS simultaneously, we have another K sequences
+            # remaining to continue the live beam search.
+            # Gather the top 2*K scores from _all_ beams.
+            # Gather 2*k top beams.
+            # Recover the beam index by floor division.
+            # Recover token id by modulo division and expand Id array for broadcasting.
+            # Update sequences for the 2*K top-k new sequences.
+            beams_to_keep = 2 * num_beams
+            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
+            topk_beam_indices = topk_indices // vocab_size
+            topk_running_sequences = gather_beams(
+                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
+            )
+            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
+            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
+
+            # 4. Check which sequences have ended
+            # Update current sequences:
+            # Did any of these sequences reach an end marker?
+            # To prevent these just finished sequences from being added to the current sequences
+            # set of active beam search sequences, set their log probs to a very large
+            # negative value.
+            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
+            topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
+
+            # 5. Get running sequences scores for next
+            # Determine the top k beam indices (from top 2*k beams) from log probs
+            # and gather top k beams (from top 2*k beams).
+            next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
+            next_running_sequences, next_running_scores = gather_beams(
+                [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
+            )
+
+            # 6. Process topk logits
+            # Further process log probs:
+            # - add length penalty
+            # - make sure no scores can be added anymore if beam is full
+            # - make sure still running sequences cannot be chosen as finalized beam
+            topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
+            beams_in_batch_are_full = (
+                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
+                & early_stopping
+            )
+            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
+            topk_log_probs += add_penalty * np.array(-1.0e7)
+
+            # 7. Get scores, sequences, is sentence finished for next.
+            # Combine sequences, scores, and flags along the beam dimension and compare
+            # new finished sequence scores to existing finished scores and select the
+            # best from the new set of beams
+            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
+            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
+            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
+            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
+            next_sequences, next_scores, next_is_sent_finished = gather_beams(
+                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
+            )
+
+            # 8. Update model kwargs.
+            # Determine the top k beam indices from the original set of all beams.
+            # With these, gather the top k beam-associated caches.
+            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
+            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
+            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
+            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
+
+            return BeamSearchState(
+                cur_len=state.cur_len + 1,
+                running_scores=next_running_scores,
+                running_sequences=next_running_sequences,
+                scores=next_scores,
+                sequences=next_sequences,
+                is_sent_finished=next_is_sent_finished,
+                model_kwargs=next_model_kwargs,
+            )
+
+        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
+        state = beam_search_body_fn(state)
+
+        if not trace:
+            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
+        else:
+            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
+
+        # Account for the edge-case where there are no finished sequences for a
+        # particular batch item. If so, return running sequences for that batch item.
+        none_finished = jnp.any(state.is_sent_finished, axis=1)
+        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
+        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
+
+        # take best beam for each batch
+        sequences = sequences[:, -1]
+        scores = scores[:, -1]
+
+        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
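The `gather_beams` indexing, demonstrated on concrete arrays (standalone sketch, not from the commit): `batch_indices` is just `[[0, 0, ...], [1, 1, ...]]`, so advanced indexing reorders the beam axis independently for every batch item; the 2*K variant above uses the same mechanism with `new_num_beams = 2 * num_beams`.

```python
import jax.numpy as jnp

batch_size, num_beams, seq_len = 2, 3, 4
beams = jnp.arange(batch_size * num_beams * seq_len).reshape(batch_size, num_beams, seq_len)

beam_indices = jnp.array([[2, 0, 1], [1, 1, 0]])  # chosen beams per batch item
batch_indices = jnp.reshape(
    jnp.arange(batch_size * num_beams) // num_beams, (batch_size, num_beams)
)  # [[0, 0, 0], [1, 1, 1]]

reordered = beams[batch_indices, beam_indices]
print(reordered.shape)     # (2, 3, 4)
print(reordered[0, 0, 0])  # 8 -> first row of batch 0 is its old beam 2
```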
@@ -2,6 +2,24 @@
 from ..file_utils import requires_backends
 
 
+class FlaxForcedBOSTokenLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxForcedEOSTokenLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxLogitsProcessor:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
@@ -25,6 +43,15 @@ class FlaxLogitsWarper:
         requires_backends(self, ["flax"])
 
 
+class FlaxMinLengthLogitsProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxTemperatureLogitsWarper:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
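These dummy classes exist so that `from transformers import FlaxMinLengthLogitsProcessor` still succeeds when Flax is not installed, failing only at instantiation with an actionable message. The idea in miniature (an illustrative re-implementation, not the actual `requires_backends` helper):

```python
class FlaxMinLengthLogitsProcessorStub:
    """Placeholder used when the flax backend is unavailable."""

    def __init__(self, *args, **kwargs):
        # importing succeeded; constructing the object is what raises
        raise ImportError(
            "FlaxMinLengthLogitsProcessor requires the flax library. "
            "Install it with `pip install flax`."
        )
```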
@@ -28,7 +28,10 @@ if is_flax_available():
     import jax
     import jax.numpy as jnp
     from transformers.generation_flax_logits_process import (
+        FlaxForcedBOSTokenLogitsProcessor,
+        FlaxForcedEOSTokenLogitsProcessor,
         FlaxLogitsProcessorList,
+        FlaxMinLengthLogitsProcessor,
         FlaxTemperatureLogitsWarper,
         FlaxTopKLogitsWarper,
         FlaxTopPLogitsWarper,
@@ -57,8 +60,8 @@ class LogitsProcessorTest(unittest.TestCase):
         temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
         temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
 
-        warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
-        warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
+        warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
+        warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
 
         # uniform distribution stays uniform
         self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
@@ -83,7 +86,7 @@ class LogitsProcessorTest(unittest.TestCase):
 
         top_k_warp = FlaxTopKLogitsWarper(3)
 
-        scores = top_k_warp(input_ids, ramp_logits)
+        scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
 
         # check that correct tokens are filtered
         self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
@@ -94,7 +97,7 @@ class LogitsProcessorTest(unittest.TestCase):
         top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
 
         ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
-        scores = top_k_warp_safety_check(input_ids, ramp_logits)
+        scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
 
         # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
         self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
@@ -108,7 +111,7 @@ class LogitsProcessorTest(unittest.TestCase):
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
 
         top_p_warp = FlaxTopPLogitsWarper(0.7)
-        filtered_dist = np.exp(top_p_warp(input_ids, dist))
+        filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
 
         # dist should be filtered to keep min num values so that sum is >= 0.7
         # exp (-inf) => 0
@@ -125,15 +128,81 @@ class LogitsProcessorTest(unittest.TestCase):
 
         # make sure at least 2 tokens are kept
         top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
-        filtered_dist = top_p_warp(input_ids, ramp_logits)
+        filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
 
         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
         self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
 
+    def test_min_length_dist_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+
+        min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
+
+        # check that min length is applied at length 5
+        input_ids = ids_tensor((batch_size, 20), vocab_size=20)
+        cur_len = 5
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
+        self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
+
+        # check that min length is not applied anymore at length 15
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        cur_len = 15
+        scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
+        self.assertFalse(jnp.isinf(scores_before_min_length).any())
+
+    def test_forced_bos_token_logits_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        bos_token_id = 0
+
+        logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
+
+        # check that all scores are -inf except the bos_token_id score
+        input_ids = ids_tensor((batch_size, 1), vocab_size=20)
+        cur_len = 1
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len=cur_len)
+        self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
+        self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0])  # score for bos_token_id should be zero
+
+        # check that bos_token_id is not forced if current length is greater than 1
+        cur_len = 3
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len=cur_len)
+        self.assertFalse(jnp.isinf(scores).any())
+
+    def test_forced_eos_token_logits_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+        max_length = 5
+
+        logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
+
+        # check that all scores are -inf except the eos_token_id when max_length is reached
+        input_ids = ids_tensor((batch_size, 4), vocab_size=20)
+        cur_len = 4
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len=cur_len)
+        self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
+        self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0])  # score for eos_token_id should be zero
+
+        # check that eos_token_id is not forced if max_length is not reached
+        cur_len = 3
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len=cur_len)
+        self.assertFalse(jnp.isinf(scores).any())
+
     def test_processor_list(self):
         batch_size = 4
         sequence_length = 10
         vocab_size = 15
+        eos_token_id = 2
+        bos_token_id = 1
+        max_length = 15
 
         # dummy input_ids and scores
         input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
@@ -147,14 +216,83 @@ class LogitsProcessorTest(unittest.TestCase):
         top_k_warp = FlaxTopKLogitsWarper(3)
         top_p_warp = FlaxTopPLogitsWarper(0.8)
 
+        # instantiate all logits processors
+        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
+        bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
+        eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
+
+        cur_len = 10
+
         # no processor list
-        scores = temp_dist_warp(input_ids, scores)
-        scores = top_k_warp(input_ids, scores)
-        scores = top_p_warp(input_ids, scores)
+        scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
+        scores = top_k_warp(input_ids, scores, cur_len=cur_len)
+        scores = top_p_warp(input_ids, scores, cur_len=cur_len)
+        scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
+        scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
+        scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
 
         # with processor list
-        processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
-        scores_comp = processor(input_ids, scores_comp)
+        processor = FlaxLogitsProcessorList(
+            [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
+        )
+        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
 
         # scores should be equal
         self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
 
         # input_ids should never be changed
         self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
 
+    def test_processor_list_jitted(self):
+        batch_size = 4
+        sequence_length = 10
+        vocab_size = 15
+        eos_token_id = 2
+        bos_token_id = 1
+        max_length = 15
+
+        # dummy input_ids and scores
+        input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
+        input_ids_comp = input_ids.copy()
+
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores_comp = scores.copy()
+
+        # instantiate all dist processors
+        temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
+        top_k_warp = FlaxTopKLogitsWarper(3)
+        top_p_warp = FlaxTopPLogitsWarper(0.8)
+
+        # instantiate all logits processors
+        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
+        bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
+        eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
+
+        cur_len = 10
+
+        # no processor list
+        def run_no_processor_list(input_ids, scores, cur_len):
+            scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
+            scores = top_k_warp(input_ids, scores, cur_len=cur_len)
+            scores = top_p_warp(input_ids, scores, cur_len=cur_len)
+            scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
+            scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
+            scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
+            return scores
+
+        # with processor list
+        def run_processor_list(input_ids, scores, cur_len):
+            processor = FlaxLogitsProcessorList(
+                [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
+            )
+            scores = processor(input_ids, scores, cur_len=cur_len)
+            return scores
+
+        jitted_run_no_processor_list = jax.jit(run_no_processor_list)
+        jitted_run_processor_list = jax.jit(run_processor_list)
+
+        scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
+        scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)
+
+        # scores should be equal
+        self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
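Why the tests also run everything under `jax.jit`: once jitted, `cur_len` arrives as a traced value rather than a Python int, which is exactly the situation the branch-free processors were written for. A minimal check of the same behavior (standalone sketch; assumes the same transformers release and a JAX version contemporary to this commit):

```python
import jax
import jax.numpy as jnp
from transformers.generation_flax_logits_process import FlaxMinLengthLogitsProcessor

proc = FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=0)
jitted = jax.jit(lambda ids, scores, cur_len: proc(ids, scores, cur_len=cur_len))

ids = jnp.zeros((1, 4), dtype="i4")
scores = jnp.zeros((1, 8))
print(jitted(ids, scores, 3)[0, 0])  # -inf: min length not yet reached
print(jitted(ids, scores, 7)[0, 0])  # 0.0: penalty no longer applied
```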
@@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
 
         self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
 
+    def test_beam_search_generate(self):
+        config, input_ids, _, max_length = self._get_input_ids_and_config()
+        config.do_sample = False
+        config.max_length = max_length
+        config.num_beams = 2
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids).sequences
+            self.assertEqual(generation_outputs.shape[-1], max_length)
+
+            jit_generate = jit(model.generate)
+            jit_generation_outputs = jit_generate(input_ids).sequences
+
+            self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
+
     def test_sample_generate_logits_warper(self):
         config, input_ids, _, max_length = self._get_input_ids_and_config()
         config.do_sample = True
@@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
         config.temperature = 0.8
         config.top_k = 10
         config.top_p = 0.3
+        config.min_length = 1
+        config.forced_bos_token_id = 8
+        config.forced_eos_token_id = 9
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids).sequences
+            self.assertEqual(generation_outputs.shape[-1], max_length)
+
+            jit_generate = jit(model.generate)
+            jit_generation_outputs = jit_generate(input_ids).sequences
+
+            self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
+
+    def test_greedy_generate_logits_warper(self):
+        config, input_ids, _, max_length = self._get_input_ids_and_config()
+        config.max_length = max_length
+        config.min_length = 1
+        config.forced_bos_token_id = 8
+        config.forced_eos_token_id = 9
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids).sequences
+            self.assertEqual(generation_outputs.shape[-1], max_length)
+
+            jit_generate = jit(model.generate)
+            jit_generation_outputs = jit_generate(input_ids).sequences
+
+            self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
+
+    def test_beam_search_generate_logits_warper(self):
+        config, input_ids, _, max_length = self._get_input_ids_and_config()
+        config.max_length = max_length
+        config.num_beams = 2
+        config.min_length = 1
+        config.forced_bos_token_id = 8
+        config.forced_eos_token_id = 9
+
         for model_class in self.all_generative_model_classes:
             model = model_class(config)
@@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
             jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
 
             self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
+
+    def test_beam_search_generate_attn_mask(self):
+        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+        # pad attention mask on the left
+        attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
+
+        config.num_beams = 2
+        config.max_length = max_length
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
+            self.assertEqual(generation_outputs.shape[-1], max_length)
+
+            jit_generate = jit(model.generate)
+            jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
+
+            self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
File diff suppressed because one or more lines are too long