[Flax] Add Beam Search (#12131)

* fix_torch_device_generate_test

* remove @

* push new logit processors

* add processors

* save first working version

* save intermediate

* finish

* make style

* make fix-copies

* finish

* Update tests/test_modeling_flax_bart.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Patrick von Platen 2021-06-16 09:43:54 +01:00 committed by GitHub
parent 802ffaff0d
commit c3c39f7e84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 840 additions and 44 deletions

View File

@ -186,6 +186,15 @@ generation.
.. autoclass:: transformers.FlaxTopKLogitsWarper
:members: __call__
.. autoclass:: transformers.FlaxForcedBOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxForcedEOSTokenLogitsProcessor
:members: __call__
.. autoclass:: transformers.FlaxMinLengthLogitsProcessor
:members: __call__
StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1486,9 +1486,12 @@ else:
# FLAX-backed objects
if is_flax_available():
_import_structure["generation_flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
@ -2814,9 +2817,12 @@ if TYPE_CHECKING:
if is_flax_available():
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,

View File

@ -81,16 +81,18 @@ class FlaxLogitsProcessorList(list):
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
) -> jax_xla.DeviceArray:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if len(function_args) > 3:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
scores = processor(input_ids, scores, **kwargs)
scores = processor(input_ids, scores, cur_len, **kwargs)
else:
scores = processor(input_ids, scores)
scores = processor(input_ids, scores, cur_len)
return scores
@ -109,7 +111,9 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
scores = scores / self.temperature
return scores
@ -137,7 +141,9 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
mask_scores = jnp.full_like(scores, self.filter_value)
@ -177,7 +183,9 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
def __call__(
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
batch_size, vocab_size = scores.shape
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
@ -190,3 +198,94 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
return next_scores
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token.

    Args:
        bos_token_id (:obj:`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        # Candidate scores in which every token is -inf except `bos_token_id`, which is set to 0.
        all_masked = jnp.full(scores.shape, -float("inf"))
        forced_scores = jax.ops.index_update(all_masked, jax.ops.index[:, self.bos_token_id], 0)

        # `jnp.bool_(cur_len - 1)` is False only for `cur_len == 1`, so the flag is 1 exactly at that step.
        # Expressed arithmetically (no Python `if`) so that `cur_len` may be a traced value.
        is_forced_step = 1 - jnp.bool_(cur_len - 1)

        return jnp.where(is_forced_step, forced_scores, scores)
class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when
    :obj:`max_length` is reached.

    Args:
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (:obj:`int`):
            The id of the token to force as the last generated token when :obj:`max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        # Candidate scores in which every token is -inf except `eos_token_id`, which is set to 0.
        all_masked = jnp.full(scores.shape, -float("inf"))
        forced_scores = jax.ops.index_update(all_masked, jax.ops.index[:, self.eos_token_id], 0)

        # The flag is 1 only when `cur_len - max_length + 1 == 0`, i.e. `cur_len == max_length - 1`.
        # Expressed arithmetically (no Python `if`) so that `cur_len` may be a traced value.
        is_forced_step = 1 - jnp.bool_(cur_len - self.max_length + 1)

        return jnp.where(is_forced_step, forced_scores, scores)
class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
    r"""
    :class:`transformers.FlaxLogitsProcessor` enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (:obj:`int`):
            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.

    Raises:
        ValueError: If :obj:`min_length` or :obj:`eos_token_id` is not a non-negative :obj:`int`.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        # 0 is accepted by both checks, so the messages say "non-negative" rather than "positive".
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:
        # Flag deciding whether the min-length penalty is applied. `clip(cur_len - min_length, 0, 1)` is 0
        # while `cur_len <= min_length` and 1 afterwards, so the penalty is active for `cur_len <= min_length`.
        # Expressed arithmetically (no Python `if`) so that `cur_len` may be a traced value.
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

        # Same scores with the EOS column replaced by -inf; selected only when the flag is set.
        scores_without_eos = jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf"))

        return jnp.where(apply_penalty, scores_without_eos, scores)

View File

@ -17,6 +17,8 @@
from typing import Dict, Optional
import numpy as np
import flax
import jax
import jax.numpy as jnp
@ -25,7 +27,10 @@ from jax import lax
from .file_utils import ModelOutput
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@ -43,9 +48,8 @@ class FlaxGreedySearchOutput(ModelOutput):
Args:
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
padded to :obj:`max_length`.
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
The generated sequences.
"""
sequences: jax_xla.DeviceArray = None
@ -58,19 +62,35 @@ class FlaxSampleOutput(ModelOutput):
Args:
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`):
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
padded to :obj:`max_length`.
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
The generated sequences.
"""
sequences: jax_xla.DeviceArray = None
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            The generated sequences.
        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jax_xla.DeviceArray = None
    scores: jax_xla.DeviceArray = None
@flax.struct.dataclass
class GreedyState:
cur_len: jax_xla.DeviceArray
sequences: jax_xla.DeviceArray
current_token: jax_xla.DeviceArray
running_token: jax_xla.DeviceArray
is_sent_finished: jax_xla.DeviceArray
model_kwargs: Dict[str, jax_xla.DeviceArray]
@ -79,12 +99,23 @@ class GreedyState:
class SampleState:
cur_len: jax_xla.DeviceArray
sequences: jax_xla.DeviceArray
current_token: jax_xla.DeviceArray
running_token: jax_xla.DeviceArray
is_sent_finished: jax_xla.DeviceArray
prng_key: jax_xla.DeviceArray
model_kwargs: Dict[str, jax_xla.DeviceArray]
@flax.struct.dataclass
class BeamSearchState:
    # Loop state carried through the beam search loop of `FlaxGenerationMixin._beam_search`.

    # current sequence length; incremented by one on every loop step
    cur_len: jax_xla.DeviceArray
    # token ids of the still-running beams, initialised as (batch_size, num_beams, max_length)
    running_sequences: jax_xla.DeviceArray
    # accumulated log-prob scores of the running beams, initialised as (batch_size, num_beams)
    running_scores: jax_xla.DeviceArray
    # token ids of the finished beams, initialised as (batch_size, num_beams, max_length) filled with the pad id
    sequences: jax_xla.DeviceArray
    # scores of the finished beams, initialised as (batch_size, num_beams) filled with -1.0e7
    scores: jax_xla.DeviceArray
    # boolean flags of shape (batch_size, num_beams) marking which entries of `sequences` are finished
    is_sent_finished: jax_xla.DeviceArray
    # model-specific keyword arguments forwarded to the model on every step
    model_kwargs: Dict[str, jax_xla.DeviceArray]
class FlaxGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
@ -110,6 +141,10 @@ class FlaxGenerationMixin:
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
return model_kwargs
@staticmethod
def _expand_to_num_beams(tensor, num_beams):
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
def generate(
self,
input_ids: jax_xla.DeviceArray,
@ -123,6 +158,13 @@ class FlaxGenerationMixin:
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
no_repeat_ngram_size: Optional[int] = None,
min_length: Optional[int] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
**model_kwargs,
@ -159,6 +201,8 @@ class FlaxGenerationMixin:
The id of the `beginning-of-sequence` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
num_beams (:obj:`int`, `optional`, defaults to 1):
Number of beams for beam search. 1 means no beam search.
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
@ -204,9 +248,27 @@ class FlaxGenerationMixin:
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_beams = num_beams if num_beams is not None else self.config.num_beams
if do_sample:
if not do_sample and num_beams == 1:
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
)
return self._greedy_search(
input_ids,
max_length,
pad_token_id,
eos_token_id,
logits_processor=logits_processor,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
elif do_sample and num_beams == 1:
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
)
return self._sample(
input_ids,
max_length,
@ -214,20 +276,43 @@ class FlaxGenerationMixin:
eos_token_id,
prng_key,
logits_warper=logits_warper,
logits_processor=logits_processor,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
elif not do_sample and num_beams > 1:
# broadcast input_ids & encoder_outputs
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=num_beams
)
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
)
return self._beam_search(
input_ids,
max_length,
pad_token_id,
eos_token_id,
length_penalty=length_penalty,
early_stopping=early_stopping,
logits_processor=logits_processor,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
else:
return self._greedy_search(
input_ids,
max_length,
pad_token_id,
eos_token_id,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
raise NotImplementedError("`Beam sampling is currently not implemented.")
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None
@ -255,12 +340,51 @@ class FlaxGenerationMixin:
return warpers
def _get_logits_processor(
    self,
    no_repeat_ngram_size: int,
    min_length: int,
    max_length: int,
    eos_token_id: int,
    forced_bos_token_id: int,
    forced_eos_token_id: int,
) -> FlaxLogitsProcessorList:
    """
    This method returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
    :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head.

    Every argument except :obj:`max_length` falls back to the attribute of the same name on :obj:`self.config`
    when it is :obj:`None`.
    """

    def from_config_if_none(value, attribute_name):
        # the config is only consulted when the caller did not supply a value
        return value if value is not None else getattr(self.config, attribute_name)

    processor_list = FlaxLogitsProcessorList()

    # resolve generation parameters
    # NOTE: `no_repeat_ngram_size` is resolved here but no processor in this method consumes it.
    no_repeat_ngram_size = from_config_if_none(no_repeat_ngram_size, "no_repeat_ngram_size")
    min_length = from_config_if_none(min_length, "min_length")
    eos_token_id = from_config_if_none(eos_token_id, "eos_token_id")
    forced_bos_token_id = from_config_if_none(forced_bos_token_id, "forced_bos_token_id")
    forced_eos_token_id = from_config_if_none(forced_eos_token_id, "forced_eos_token_id")

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    min_length_is_usable = min_length is not None and eos_token_id is not None and min_length > -1
    if min_length_is_usable:
        processor_list.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))

    if forced_bos_token_id is not None:
        processor_list.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))

    if forced_eos_token_id is not None:
        processor_list.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))

    return processor_list
def _greedy_search(
self,
input_ids: None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
@ -293,7 +417,7 @@ class FlaxGenerationMixin:
state = GreedyState(
cur_len=cur_len,
sequences=sequences,
current_token=input_ids,
running_token=input_ids,
is_sent_finished=is_sent_finished,
model_kwargs=model_kwargs,
)
@ -307,8 +431,13 @@ class FlaxGenerationMixin:
def greedy_search_body_fn(state):
"""state update fn."""
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
logits = model_outputs.logits[:, -1]
# apply min_length, ...
logits = logits_processor(state.sequences, logits, state.cur_len)
next_token = jnp.argmax(logits, axis=-1)
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
@ -319,7 +448,7 @@ class FlaxGenerationMixin:
return GreedyState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
current_token=next_token,
running_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
)
@ -342,6 +471,7 @@ class FlaxGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
prng_key: Optional[jax_xla.DeviceArray] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
@ -377,7 +507,7 @@ class FlaxGenerationMixin:
state = SampleState(
cur_len=cur_len,
sequences=sequences,
current_token=input_ids,
running_token=input_ids,
is_sent_finished=is_sent_finished,
prng_key=prng_key,
model_kwargs=model_kwargs,
@ -393,12 +523,14 @@ class FlaxGenerationMixin:
def sample_search_body_fn(state):
"""state update fn."""
prng_key, prng_key_next = jax.random.split(state.prng_key)
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
logits = model_outputs.logits[:, -1]
# apply min_length, ...
logits = logits_processor(state.sequences, logits, state.cur_len)
# apply top_k, top_p, temperature
logits = logits_warper(state.sequences, logits)
logits = logits_warper(logits, logits, state.cur_len)
next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
@ -412,7 +544,7 @@ class FlaxGenerationMixin:
return SampleState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
current_token=next_token,
running_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
prng_key=prng_key_next,
@ -428,3 +560,251 @@ class FlaxGenerationMixin:
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
return FlaxSampleOutput(sequences=state.sequences)
def _beam_search(
    self,
    input_ids: None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
):
    """
    This beam search function is heavily inspired by Flax's official example:
    https://github.com/google/flax/blob/master/examples/wmt/train.py#L254

    Args:
        input_ids:
            Prompt token ids, unpacked as :obj:`(batch_size, num_beams, cur_len)`.
        max_length (:obj:`int`, `optional`):
            Length of the generated sequences. Falls back to :obj:`self.config.max_length`.
        pad_token_id (:obj:`int`, `optional`):
            Id used to fill the sequence buffers. Falls back to :obj:`self.config.pad_token_id`.
        eos_token_id (:obj:`int`, `optional`):
            Id marking a finished sequence. Falls back to :obj:`self.config.eos_token_id`.
        length_penalty (:obj:`float`, `optional`):
            Exponent applied to the length when normalising scores. Falls back to
            :obj:`self.config.length_penalty`.
        early_stopping (:obj:`bool`, `optional`):
            Whether to stop once all finished-beam slots are filled. Falls back to
            :obj:`self.config.early_stopping`.
        logits_processor (:obj:`FlaxLogitsProcessorList`, `optional`):
            Processors applied to the log probabilities on every step.
        trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If :obj:`True` the loop runs in :obj:`lax.while_loop`, otherwise in :obj:`self._run_loop_in_debug`.
        params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
            Parameters forwarded to the model call.
        model_kwargs (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
            Additional model inputs. :obj:`encoder_outputs` and :obj:`attention_mask`, if present, are flattened
            in place over their first two dimensions.

    Returns:
        :obj:`FlaxBeamSearchOutput` holding, for each batch item, the last beam of the selected sequences and
        its score.
    """

    def flatten_beam_dim(tensor):
        """Flattens the first two dimensions of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

    def gather_beams(nested, beam_indices, batch_size, new_num_beams):
        """
        Gathers the beam slices indexed by beam_indices into new beam array.
        """
        batch_indices = jnp.reshape(
            jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
        )

        def gather_fn(tensor):
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            else:
                return tensor[batch_indices, beam_indices]

        return jax.tree_map(gather_fn, nested)

    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

    batch_size, num_beams, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch,beam-item holding current token in loop.
    # `sequences` collects finished beams; `running_sequences` starts as the same pad-filled buffer with
    # the prompt written at position 0.
    sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

    # per batch,beam-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

    # per batch,beam-item score, logprobs
    # only the first running beam starts at 0, all others at -1.0e7
    running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
    scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

    # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
    # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # flatten beam dim
    if "encoder_outputs" in model_kwargs:
        model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
            model_kwargs["encoder_outputs"]["last_hidden_state"]
        )
    if "attention_mask" in model_kwargs:
        model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

    # initialize state
    state = BeamSearchState(
        cur_len=cur_len,
        running_sequences=running_sequences,
        running_scores=running_scores,
        sequences=sequences,
        scores=scores,
        is_sent_finished=is_sent_finished,
        model_kwargs=model_kwargs,
    )

    def beam_search_cond_fn(state):
        """beam search state termination condition fn."""

        # 1. is less than max length?
        not_max_length_yet = state.cur_len < max_length

        # 2. can the new beams still improve?
        best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
        worst_finished_score = jnp.where(
            state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
        )
        improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

        # 3. is there still a beam that has not finished?
        still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

        return not_max_length_yet & still_open_beam & improvement_still_possible

    def beam_search_body_fn(state):
        """beam search state update fn."""
        # 1. Forward current tokens
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model. Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # Afterwards, unflatten the beam dimension of the logits and of the
        # attention cache arrays.
        input_token = flatten_beam_dim(
            lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
        )
        model_outputs = model(input_token, params=params, **state.model_kwargs)
        logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
        cache = jax.tree_map(
            lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
        )

        # 2. Compute log probs
        # get log probabilities from logits,
        # process logits with processors (*e.g.* min_length, ...), and
        # add new logprobs to existing running logprobs scores.
        log_probs = jax.nn.log_softmax(logits)
        # Pass the sequences of the *current* loop state; the enclosing function's `running_sequences`
        # only holds the initial prompt and would be stale after the first step.
        log_probs = logits_processor(
            flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
        )
        log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
        vocab_size = log_probs.shape[2]
        log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

        # 3. Retrieve top-K
        # Each item in batch has num_beams * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        # Gather the top 2*K scores from _all_ beams.
        # Gather 2*k top beams.
        # Recover the beam index by floor division.
        # Recover token id by modulo division and expand Id array for broadcasting.
        # Update sequences for the 2*K top-k new sequences.
        beams_to_keep = 2 * num_beams
        topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
        topk_beam_indices = topk_indices // vocab_size
        topk_running_sequences = gather_beams(
            state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
        )
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

        # 4. Check which sequences have ended
        # Update current sequences:
        # Did any of these sequences reach an end marker?
        # To prevent these just finished sequences from being added to the current sequences
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
        topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

        # 5. Get running sequences scores for next
        # Determine the top k beam indices (from top 2*k beams) from log probs
        # and gather top k beams (from top 2*k beams).
        # The indices are flipped so that the best beam ends up in the last position.
        next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
        next_running_sequences, next_running_scores = gather_beams(
            [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
        )

        # 6. Process topk logits
        # Further process log probs:
        # - add length penalty
        # - make sure no scores can be added anymore if beam is full
        # - make sure still running sequences cannot be chosen as finalized beam
        topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
        beams_in_batch_are_full = (
            jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
            & early_stopping
        )
        add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
        topk_log_probs += add_penalty * np.array(-1.0e7)

        # 7. Get scores, sequences, is sentence finished for next.
        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams
        merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
        merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
        merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
        topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
        next_sequences, next_scores, next_is_sent_finished = gather_beams(
            [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
        )

        # 8. Update model kwargs.
        # Determine the top k beam indices from the original set of all beams.
        # With these, gather the top k beam-associated caches.
        next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
        next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
        model_outputs["past_key_values"] = jax.tree_map(flatten_beam_dim, next_cache)
        next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

        return BeamSearchState(
            cur_len=state.cur_len + 1,
            running_scores=next_running_scores,
            running_sequences=next_running_sequences,
            scores=next_scores,
            sequences=next_sequences,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
        )

    # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
    state = beam_search_body_fn(state)

    if not trace:
        state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
    else:
        state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return running sequences for that batch item.
    any_finished = jnp.any(state.is_sent_finished, axis=1)
    sequences = jnp.where(any_finished[:, None, None], state.sequences, state.running_sequences)
    scores = jnp.where(any_finished[:, None], state.scores, state.running_scores)

    # take best beam for each batch (beams are ordered so that the best one is last)
    sequences = sequences[:, -1]
    scores = scores[:, -1]

    return FlaxBeamSearchOutput(sequences=sequences, scores=scores)

View File

@ -2,6 +2,24 @@
from ..file_utils import requires_backends
class FlaxForcedBOSTokenLogitsProcessor:
    # Stand-in object: both entry points only delegate to `requires_backends` with the "flax" backend.
    # NOTE(review): presumably used when flax is not installed so that usage reports the missing backend - confirm.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
class FlaxForcedEOSTokenLogitsProcessor:
    # Stand-in object: both entry points only delegate to `requires_backends` with the "flax" backend.
    # NOTE(review): presumably used when flax is not installed so that usage reports the missing backend - confirm.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
class FlaxLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -25,6 +43,15 @@ class FlaxLogitsWarper:
requires_backends(self, ["flax"])
class FlaxMinLengthLogitsProcessor:
    # Stand-in object: both entry points only delegate to `requires_backends` with the "flax" backend.
    # NOTE(review): presumably used when flax is not installed so that usage reports the missing backend - confirm.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
class FlaxTemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

View File

@ -28,7 +28,10 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@ -57,8 +60,8 @@ class LogitsProcessorTest(unittest.TestCase):
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
# uniform distribution stays uniform
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
@ -83,7 +86,7 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp = FlaxTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
# check that correct tokens are filtered
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
@ -94,7 +97,7 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
@ -108,7 +111,7 @@ class LogitsProcessorTest(unittest.TestCase):
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
filtered_dist = np.exp(top_p_warp(input_ids, dist))
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
@ -125,15 +128,81 @@ class LogitsProcessorTest(unittest.TestCase):
# make sure at least 2 tokens are kept
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
def test_min_length_dist_processor(self):
    """EOS must be masked to -inf below `min_length` and no score may be infinite above it."""
    vocab_size, batch_size, eos_token_id = 20, 4, 0

    processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
    input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)

    # cur_len=5 is below min_length=10: the EOS column is expected to be -inf for every row
    masked_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=5)
    self.assertListEqual(masked_scores[:, eos_token_id].tolist(), [-float("inf")] * batch_size)

    # cur_len=15 is above min_length=10: nothing should be masked any more
    unmasked_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=15)
    self.assertFalse(jnp.isinf(unmasked_scores).any())
def test_forced_bos_token_logits_processor(self):
    """BOS must be the only finite score at `cur_len == 1`; at `cur_len == 3` no score may be infinite."""
    vocab_size, batch_size, bos_token_id = 20, 4, 0

    forced_bos = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
    input_ids = ids_tensor((batch_size, 1), vocab_size=vocab_size)

    # cur_len=1: every score after bos_token_id must be -inf ...
    first_step = forced_bos(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=1)
    self.assertTrue(jnp.isneginf(first_step[:, bos_token_id + 1 :]).all())
    # ... while the score for bos_token_id itself should be zero
    self.assertListEqual(first_step[:, bos_token_id].tolist(), batch_size * [0])

    # cur_len=3: bos_token_id is not forced once the current length is greater than 1
    later_step = forced_bos(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=3)
    self.assertFalse(jnp.isinf(later_step).any())
def test_forced_eos_token_logits_processor(self):
    """EOS must be the only finite score at `cur_len == 4` (with `max_length=5`); at `cur_len == 3` nothing is masked."""
    vocab_size, batch_size, eos_token_id, max_length = 20, 4, 0, 5

    forced_eos = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
    input_ids = ids_tensor((batch_size, 4), vocab_size=vocab_size)

    # cur_len=4: every score after eos_token_id must be -inf ...
    last_step = forced_eos(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=4)
    self.assertTrue(jnp.isneginf(last_step[:, eos_token_id + 1 :]).all())
    # ... while the score for eos_token_id itself should be zero
    self.assertListEqual(last_step[:, eos_token_id].tolist(), batch_size * [0])

    # cur_len=3: eos_token_id is not forced because max_length is not reached yet
    earlier_step = forced_eos(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=3)
    self.assertFalse(jnp.isinf(earlier_step).any())
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 2
bos_token_id = 1
max_length = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
@ -147,14 +216,83 @@ class LogitsProcessorTest(unittest.TestCase):
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
cur_len = 10
# no processor list
scores = temp_dist_warp(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
# with processor list
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
scores_comp = processor(input_ids, scores_comp)
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
def test_processor_list_jitted(self):
    """Under `jax.jit`, applying the processors one by one must match `FlaxLogitsProcessorList`.

    Both code paths are compiled with `jax.jit` and fed the same `input_ids`,
    uniform scores and `cur_len`; the resulting scores are compared with
    `jnp.allclose`, and `input_ids` is checked to be unchanged afterwards.
    """
    batch_size = 4
    sequence_length = 10
    vocab_size = 15
    eos_token_id = 2
    bos_token_id = 1
    max_length = 15

    # dummy input_ids and scores (the `_comp` copies serve as the reference / second code path)
    input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
    input_ids_comp = input_ids.copy()

    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores_comp = scores.copy()

    # instantiate all dist processors
    temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
    top_k_warp = FlaxTopKLogitsWarper(3)
    top_p_warp = FlaxTopPLogitsWarper(0.8)

    # instantiate all logits processors
    min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
    bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
    eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)

    cur_len = 10

    # no processor list: chain every processor by hand
    def run_no_processor_list(input_ids, scores, cur_len):
        scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
        scores = top_k_warp(input_ids, scores, cur_len=cur_len)
        scores = top_p_warp(input_ids, scores, cur_len=cur_len)
        scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
        return scores

    # with processor list: same processors, same order
    def run_processor_list(input_ids, scores, cur_len):
        processor = FlaxLogitsProcessorList(
            [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
        )
        scores = processor(input_ids, scores, cur_len=cur_len)
        return scores

    jitted_run_no_processor_list = jax.jit(run_no_processor_list)
    jitted_run_processor_list = jax.jit(run_processor_list)

    scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
    scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)

    # scores should be equal
    self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))

    # input_ids should never be changed (this check was missing, leaving `input_ids_comp` unused)
    self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())

View File

@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate(self):
    """With `num_beams=2` and sampling off, eager and jitted `generate` must return identical sequences."""
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.num_beams = 2
    config.do_sample = False
    config.max_length = max_length

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: the last dimension of the output is expected to equal max_length
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # compiled run of the very same method must reproduce the eager result
        compiled_sequences = jit(model.generate)(input_ids).sequences
        self.assertListEqual(eager_sequences.tolist(), compiled_sequences.tolist())
def test_sample_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
config.temperature = 0.8
config.top_k = 10
config.top_p = 0.3
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_logits_warper(self):
    """Eager and jitted `generate` must agree when `min_length` and forced BOS/EOS token ids are configured."""
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.max_length = max_length
    # options that switch on the min-length / forced-token handling during generation
    config.forced_eos_token_id = 9
    config.forced_bos_token_id = 8
    config.min_length = 1

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: the last dimension of the output is expected to equal max_length
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # compiled run of the very same method must reproduce the eager result
        compiled_sequences = jit(model.generate)(input_ids).sequences
        self.assertListEqual(eager_sequences.tolist(), compiled_sequences.tolist())
def test_beam_search_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.num_beams = 2
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_attn_mask(self):
    """With `num_beams=2` and position (0, 0) of the attention mask zeroed, eager and jitted `generate` must agree."""
    config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

    # pad attention mask on the left
    # `jax.ops.index_update` is deprecated and removed in newer JAX releases; the
    # `.at[...].set(...)` functional update is its replacement. `asarray` keeps the
    # old helper's tolerance for array inputs that are not already JAX arrays.
    attention_mask = jax.numpy.asarray(attention_mask).at[0, 0].set(0)

    config.num_beams = 2
    config.max_length = max_length

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: the last dimension of the output is expected to equal max_length
        generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
        self.assertEqual(generation_outputs.shape[-1], max_length)

        # compiled run must reproduce the eager result exactly
        jit_generate = jit(model.generate)
        jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences

        self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())

File diff suppressed because one or more lines are too long