TF generate refactor - Beam Search (#16374)

* refactor TF beam search

* refactored generate can now properly use attention masks

* add force bos/eos logit processors
Joao Gante 2022-04-06 18:19:34 +01:00 committed by GitHub
parent 4d10083539
commit 3f43d824b9
11 changed files with 796 additions and 56 deletions

View File

@ -178,6 +178,12 @@ generation.
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] TFForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] TFForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxLogitsProcessor
- __call__

View File

@ -1699,6 +1699,8 @@ if is_tf_available():
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
@ -3827,6 +3829,8 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,

View File

@ -216,14 +216,10 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
# create boolean flag to decide if min length penalty should be applied
cur_len = input_ids.shape[-1]
apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
# generate is not XLA-compilable anyway
if apply_penalty:
if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
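For reference, a minimal eager-mode sketch of the new `cur_len` argument on `TFMinLengthLogitsProcessor` (toy ids, shapes, and token ids are assumptions; the updated tests further down exercise the same behaviour):
```python
import tensorflow as tf

from transformers.generation_tf_logits_process import TFMinLengthLogitsProcessor

processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=1)
dummy_ids = tf.ones((2, 5), dtype=tf.int32)      # (batch_size, cur_len)
dummy_scores = tf.random.normal((2, 8))          # (batch_size, vocab_size)

# cur_len is now passed explicitly instead of being read from input_ids.shape[-1]
masked = processor(dummy_ids, dummy_scores, cur_len=5)
# cur_len < min_length, so the eos_token_id column is pushed to -inf
```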
@ -259,8 +255,8 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids, scores)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
scores = tf.math.multiply(scores, score_penalties)
@ -330,12 +326,12 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
return banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
vocab_size = scores.shape[-1]
# calculate a list of banned tokens according to bad words
banned_tokens = self.calc_banned_bad_words_ids(input_ids)
banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
@ -365,12 +361,13 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search
if cur_len + 1 < self.ngram_size:
# return no banned tokens if we haven't generated ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
prev_input_ids = input_ids[:, :cur_len]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].numpy().tolist()
generated_ngram = generated_ngrams[idx]
@ -388,10 +385,9 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
return banned_tokens
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
batch_size, vocab_size = scores.shape
cur_len = input_ids.shape[-1]
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
# create banned_tokens boolean mask
@ -406,3 +402,66 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
)
return scores
class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces the specified token as the first generated token.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
"""
def __init__(self, bos_token_id: int):
if bos_token_id < 0:
raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
self.bos_token_id = bos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
if cur_len == 1:
batch_size, num_tokens = scores.shape
# sets the score to 0 in the bos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.bos_token_id > 0:
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
if self.bos_token_id < (num_tokens - 1):
scores = tf.concat(
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
axis=-1,
)
return scores
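A small eager-mode sketch of the processor above (toy shapes; the concrete numbers are illustrative only):
```python
import tensorflow as tf

from transformers.generation_tf_logits_process import TFForcedBOSTokenLogitsProcessor

processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=2)
dummy_ids = tf.ones((1, 1), dtype=tf.int32)      # cur_len == 1, i.e. the first generated token
dummy_scores = tf.random.normal((1, 5))          # (batch_size, vocab_size)

forced = processor(dummy_ids, dummy_scores, cur_len=1)
# forced == [[-inf, -inf, 0., -inf, -inf]]: argmax/sampling can only pick bos_token_id
```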
class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
The id of the token to force as the last generated token when `max_length` is reached.
"""
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length
if eos_token_id < 0:
raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
self.eos_token_id = eos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
if cur_len == self.max_length - 1:
batch_size, num_tokens = scores.shape
# sets the score to 0 in the eos_token_id column
scores = tf.zeros((batch_size, 1))
# sets the score to -inf everywhere else
if self.eos_token_id > 0:
scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
if self.eos_token_id < (num_tokens - 1):
scores = tf.concat(
(scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
axis=-1,
)
return scores
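And the analogous sketch for the forced-EOS processor, which only fires at `cur_len == max_length - 1` (again with toy values):
```python
import tensorflow as tf

from transformers.generation_tf_logits_process import TFForcedEOSTokenLogitsProcessor

processor = TFForcedEOSTokenLogitsProcessor(max_length=5, eos_token_id=1)
dummy_ids = tf.ones((2, 4), dtype=tf.int32)      # cur_len == max_length - 1
dummy_scores = tf.random.normal((2, 8))

forced = processor(dummy_ids, dummy_scores, cur_len=4)
# every column except eos_token_id is now -inf, so the last generated token must be EOS
```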

View File

@ -16,12 +16,15 @@
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from .generation_tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
@ -560,7 +563,7 @@ class TFGenerationMixin:
num_beams = num_beams if num_beams is not None else self.config.num_beams
do_sample = do_sample if do_sample is not None else self.config.do_sample
if num_beams == 1:
if do_sample is False or num_beams == 1:
return self._generate(
input_ids=input_ids,
max_length=max_length,
@ -586,6 +589,8 @@ class TFGenerationMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
)
# We cannot generate if the model does not have a LM head
@ -1377,8 +1382,8 @@ class TFGenerationMixin:
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model.
Return:
[`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
@ -1452,12 +1457,20 @@ class TFGenerationMixin:
# 1. Set generation parameters if not already defined
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1489,10 +1502,13 @@ class TFGenerationMixin:
model_kwargs["output_hidden_states"] = output_hidden_states
if use_cache is not None:
model_kwargs["use_cache"] = use_cache
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
# 4. Prepare model inputs which will be used for auto-regressive generation
@ -1519,6 +1535,7 @@ class TFGenerationMixin:
# TODO(Matt, Joao, Patrick) - add more use cases here
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
# 6. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
@ -1526,7 +1543,10 @@ class TFGenerationMixin:
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
)
# 7. go into different generation modes
@ -1571,18 +1591,62 @@ class TFGenerationMixin:
**model_kwargs,
)
elif is_beam_gen_mode:
if num_beams < num_return_sequences:
raise ValueError(
"Greedy beam search decoding cannot return more sequences than it has beams. Please set "
f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
)
# 8. broadcast inputs to the desired number of beams
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=num_beams
)
# 9. run beam search
return self.beam_search(
input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
length_penalty=length_penalty,
early_stopping=early_stopping,
logits_processor=logits_processor,
return_dict_in_generate=return_dict_in_generate,
num_return_sequences=num_return_sequences,
**model_kwargs,
)
else:
# TODO(Matt, Joao, Patrick) - add more sub-generation methods here
raise NotImplementedError("Beam sampling is currently not implemented.")
@staticmethod
def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
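A shape-only sketch of what `_expand_to_num_beams` does (toy sizes are assumptions):
```python
import tensorflow as tf

tensor = tf.reshape(tf.range(6, dtype=tf.int32), (2, 3))        # (batch_size=2, seq_len=3)
num_beams = 4
expanded = tf.broadcast_to(tensor[:, None], (2, num_beams, 3))  # (batch_size, num_beams, seq_len)
# each batch row is repeated along the new beam axis; beam_search later flattens it back
# to (batch_size * num_beams, seq_len) before calling the model
```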
def _prepare_attention_mask_for_generation(
self,
input_ids: tf.Tensor,
inputs: tf.Tensor,
pad_token_id: int,
) -> tf.Tensor:
# prepare `attention_mask` if not passed
if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id):
return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs:
return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)
return tf.ones(inputs.shape[:2], dtype=tf.int32)
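A minimal sketch of the mask that gets built when the inputs are padded token ids (the ids and pad_token_id are toy values):
```python
import tensorflow as tf

pad_token_id = 0
inputs = tf.constant([[5, 6, 0, 0], [7, 8, 9, 0]], dtype=tf.int32)

attention_mask = tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
# [[1 1 0 0]
#  [1 1 1 0]]  -> 1 for real tokens, 0 for padding; stored as `model_kwargs["attention_mask"]`
```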
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]:
# get encoder and store encoder outputs
encoder = self.get_encoder()
@ -1595,11 +1659,9 @@ class TFGenerationMixin:
}
# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
encoder_kwargs["return_dict"] = True
encoder_kwargs[self.main_input_name] = inputs_tensor
encoder_outputs = encoder(**encoder_kwargs)
model_kwargs["encoder_outputs"] = encoder_outputs
return model_kwargs
@ -1757,7 +1819,10 @@ class TFGenerationMixin:
no_repeat_ngram_size: int,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
@ -1781,6 +1846,10 @@ class TFGenerationMixin:
processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
if forced_bos_token_id is not None:
processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
return processors
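As a hedged end-to-end sketch (reusing the `t5-base` checkpoint from the `beam_search` docstring example below; the decoded strings are not asserted here), the forced-token ids flow from `generate` into this processor list:
```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")
input_ids = tokenizer("translate English to German: How old are you?", return_tensors="tf").input_ids

# forced_eos_token_id appends TFForcedEOSTokenLogitsProcessor, guaranteeing the sequence
# ends with EOS if generation runs all the way to max_length
outputs = model.generate(
    input_ids,
    num_beams=3,
    max_length=16,
    forced_eos_token_id=model.config.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```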
@ -1949,7 +2018,7 @@ class TFGenerationMixin:
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[: current_pos[0]])
next_tokens_scores = logits_processor(input_ids, next_token_logits)
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
# argmax
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@ -2064,7 +2133,7 @@ class TFGenerationMixin:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`TFLogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
@ -2184,7 +2253,7 @@ class TFGenerationMixin:
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
@ -2259,6 +2328,527 @@ class TFGenerationMixin:
else:
return input_ids
def beam_search(
self,
input_ids: tf.Tensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
num_return_sequences: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
) -> Union[TFBeamSearchOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
logits_processor (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
num_return_sequences (`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an
encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`],
[`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the
generated tokens (default behaviour) or a [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`] if
`model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
[`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... TFAutoModelForSeq2SeqLM,
... TFLogitsProcessorList,
... TFMinLengthLogitsProcessor,
... )
>>> import tensorflow as tf
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
>>> # lets run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = tf.ones((num_beams, 1), dtype=tf.int64)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(
... tf.repeat(encoder_input_ids, num_beams, axis=0), return_dict=True
... )
... }
>>> # instantiate logits processors
>>> logits_processor = TFLogitsProcessorList(
... [TFMinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
... )
>>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
```"""
def flatten_beam_dim(tensor, batch_axis=0):
"""Flattens the first two dimensions of a non-scalar array."""
# ignore scalars (e.g. cache index)
if tf.rank(tensor) == 0:
return tensor
return tf.reshape(
tensor,
tensor.shape[:batch_axis]
+ [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
+ tensor.shape[batch_axis + 2 :],
)
def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
# ignore scalars (e.g. cache index)
if tf.rank(tensor) == 0:
return tensor
return tf.reshape(
tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
)
def gather_beams(nested, beam_indices, batch_axis=0):
"""Gathers the beam slices indexed by beam_indices into new beam array."""
def gather_fn(tensor):
# ignore scalars (e.g. cache index)
if tf.rank(tensor) == 0:
return tensor
else:
if batch_axis > 0:
# pushes all dimensions before the batch to the end, so we get (batch, beam_id, ...)
perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
range(batch_axis)
)
tensor = tf.transpose(tensor, perm=perm)
gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
if batch_axis > 0:
# transposes back to the original dimensions
perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
range(batch_axis)
)
perm = tf.math.invert_permutation(perm)
gathered_tensor = tf.transpose(gathered_tensor, perm=perm)
return gathered_tensor
return tf.nest.map_structure(gather_fn, nested)
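As an aside, a shape-only illustration of the three helpers above under eager execution (toy sizes; `gather_beams` relies on the same `tf.gather(..., axis=1, batch_dims=1)` call as the helper):
```python
import tensorflow as tf

beamed = tf.zeros((2, 3, 7))                  # (batch_size, num_beams, seq_len)
flat = tf.reshape(beamed, (2 * 3, 7))         # flatten_beam_dim -> (batch_size * num_beams, seq_len)
unflat = tf.reshape(flat, (2, 3, 7))          # unflatten_beam_dim restores the beam axis

# gather_beams reorders the beam axis independently for each batch item
beam_indices = tf.constant([[2, 0, 1], [1, 1, 0]])
reordered = tf.gather(unflat, beam_indices, axis=1, batch_dims=1)  # still (2, 3, 7)
```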
# 1. init beam_search values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find a programmatic way to detect the cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# 3. init tensors to use for "xla-compileable" generate function
batch_size, num_beams, cur_len = input_ids.shape
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
sequences = tf.TensorArray(
element_shape=(batch_size, num_beams),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
running_sequences = tf.TensorArray(
element_shape=(batch_size, num_beams),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
intermediary_running_sequences = tf.TensorArray(
element_shape=(batch_size, num_beams * 2),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
for i in range(max_length):
sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
intermediary_running_sequences = intermediary_running_sequences.write(
i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
)
# write prompt to running_sequences
for i in range(cur_len):
running_sequences = running_sequences.write(i, input_ids[:, :, i])
# per batch, beam-item state bit indicating if the sentence has finished.
is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
# per batch, beam-item score, logprobs
running_scores = tf.tile(
tf.expand_dims(tf.convert_to_tensor([0.0] + [-1.0e9] * (num_beams - 1)), axis=0), [batch_size, 1]
)
scores = tf.ones((batch_size, num_beams)) * -1.0e9
# flatten beam dim
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
model_kwargs["encoder_outputs"]["last_hidden_state"]
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# define stop-condition and auto-regressive function
def beam_search_cond_fn(
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
):
"""
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
False
"""
# 1. is less than max length?
not_max_length_yet = cur_len < max_length
# 2. can the new beams still improve?
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
worst_finished_score = tf.where(
is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
# 3. is there still a beam that has not finished?
still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
return not_max_length_yet & (still_open_beam | improvement_still_possible)
def beam_search_body_fn(
cur_len,
running_sequences,
running_scores,
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length=1,
intermediary_running_sequences=None,
):
"""
Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
seen so far
"""
# TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
# Alternatively, attempt to rewrite the function with permuted axes when enabling XLA.
# 1. Forward current tokens
# TF places the dynamic dimension (seq_len) in the first axis; we want it in the last
running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
input_token = tf.slice(
running_sequences_seq_last,
(0, 0, cur_len - input_ids_length),
(batch_size, num_beams, input_ids_length),
)
model_inputs = self.prepare_inputs_for_generation(
flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs
)
model_outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
scores.append(model_outputs.logits[:, -1])
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(model_outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
decoder_attentions.append(model_outputs.attentions)
if self.config.is_encoder_decoder:
cross_attentions.append(model_outputs.cross_attentions)
if output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(model_outputs.decoder_hidden_states)
elif output_hidden_states and not self.config.is_encoder_decoder:
decoder_hidden_states.append(model_outputs.hidden_states)
# 2. Compute log probs
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
)
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
vocab_size = log_probs.shape[2]
log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
# 3. Retrieve top-K
# Each item in batch has num_beams * vocab_size candidate sequences. For each item, get the top 2*k
# candidates with the highest log-probabilities. We gather the top 2*K beams here so that even if the
# best K sequences reach EOS simultaneously, we have another K sequences remaining to continue the live
# beam search.
# Gather the top 2*K scores from _all_ beams.
# Gather 2*k top beams.
# Recover the beam index by floor division.
# Recover token id by modulo division and expand Id array for broadcasting.
# Update sequences for the 2*K top-k new sequences.
beams_to_keep = 2 * num_beams
topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
topk_beam_indices = topk_indices // vocab_size
topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
topk_ids = topk_indices % vocab_size
# writes the new token
intermediary_running_sequences = intermediary_running_sequences.unstack(
tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
)
topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
# 4. Check which sequences have ended
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just-finished sequences from being added to the set of active beam
# search sequences, set their log probs to a very large negative value.
eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
if eos_token_id is None:
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
eos_in_next_token.shape,
)
# non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
# running sentences either
running_topk_log_probs = topk_log_probs + tf.cast(eos_in_next_token, tf.float32) * -1.0e9
# 5. Get running sequences scores for next
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# (from top 2*k beams).
next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
next_running_sequences_seq_last, next_running_scores = gather_beams(
[topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
)
# 6. Process topk logits
# Further process log probs:
# - add length penalty
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs = topk_log_probs / (cur_len**length_penalty)
beams_in_batch_are_full = (
tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
)
& early_stopping
)
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9
# 7. Get scores, sequences, is sentence finished for next.
# Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
# to existing finished scores and select the best from the new set of beams
sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
merged_scores = tf.concat([scores, topk_log_probs], axis=1)
merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
)
# 8. Prepare data for the next iteration
# Determine the top k beam indices from the original set of all beams. With these, gather the top k
# beam-associated caches.
if "past_key_values" in model_outputs:
cache = tf.nest.map_structure(
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
model_outputs.past_key_values,
)
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices)
next_cache = gather_beams(cache, next_running_indices, batch_axis=cache_batch_axis)
model_outputs["past_key_values"] = tf.nest.map_structure(
lambda tensor: flatten_beam_dim(tensor, batch_axis=cache_batch_axis), next_cache
)
if use_xla:
next_model_kwargs = self._update_model_kwargs_for_xla_generation(
model_outputs, model_kwargs, cur_len, max_length
)
else:
next_model_kwargs = self._update_model_kwargs_for_generation(
model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
input_ids_length = cur_len + 1
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# 9. Prepare the `tf.TensorArray` for the next iteration
next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
next_running_sequences = running_sequences.unstack(
tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
)
return (
cur_len + 1,
next_running_sequences,
next_running_scores,
next_sequences,
next_scores,
next_is_sent_finished,
next_model_kwargs,
)
# 5. run generation
# Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
beam_search_body_fn = partial(
beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
)
# 1st generation step has to be run first to initialize `past`
beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len)
(
cur_len,
running_sequences,
running_scores,
sequences,
scores,
is_sent_finished,
model_kwargs,
) = beam_search_body_fn_first_iter(
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
)
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
if beam_search_cond_fn(
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
):
maximum_iterations = max_length - cur_len
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
beam_search_cond_fn,
beam_search_body_fn,
(cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs
# convert the sequences to tf.Tensor with shape (batch_size, num_beams, seq_len)
sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# running sequences for that batch item.
none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
scores = tf.where(none_finished[:, None], scores, running_scores)
# Take best beams for each batch (the scores are sorted in descending order)
sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])
if not use_xla:
# Cut for backward compatibility
sequences_seq_last = sequences_seq_last[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
return TFBeamSearchEncoderDecoderOutput(
sequences=sequences_seq_last,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFBeamSearchDecoderOnlyOutput(
sequences=sequences_seq_last,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return sequences_seq_last
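For completeness, a hedged usage sketch of the new path through `generate` (checkpoint and inputs are illustrative; per the commit message, the attention mask built from the padded batch is now forwarded and broadcast per beam):
```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")

batch = tokenizer(
    ["translate English to German: How old are you?", "translate English to German: Thank you!"],
    padding=True,
    return_tensors="tf",
)
# do_sample defaults to False, so num_beams > 1 routes into the beam_search method above
outputs = model.generate(**batch, num_beams=4, num_return_sequences=2, max_length=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```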
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
@ -2445,7 +3035,6 @@ class BeamHypotheses(object):
If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:

View File

@ -1245,7 +1245,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
)
model_kwargs["attention_mask"] = context_attention_mask

View File

@ -17,6 +17,20 @@ class TensorFlowBenchmark(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]

View File

@ -472,14 +472,14 @@ class LogitsProcessorTest(unittest.TestCase):
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
# check that eos_token_id is not forced if max_length-1 is not reached
input_ids = ids_tensor((batch_size, 3), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)

View File

@ -26,6 +26,8 @@ if is_tf_available():
import tensorflow as tf
from transformers.generation_tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
@ -43,7 +45,7 @@ if is_tf_available():
@require_tf
class TFLogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = np.ones((batch_size, length), dtype=np.float32) / length
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
return scores
def test_min_length_dist_processor(self):
@ -54,15 +56,17 @@ class TFLogitsProcessorTest(unittest.TestCase):
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
cur_len = 5
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
cur_len = 15
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())
def test_temperature_dist_warper(self):
@ -72,8 +76,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores = scores.numpy()
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
scores = tf.convert_to_tensor(scores)
# compute softmax
probs = tf.nn.softmax(scores, axis=-1)
@ -97,8 +103,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
def test_repetition_penalty_dist_process(self):
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
vocab_size = 10
cur_len = 2
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
self.assertEqual(cur_len, input_ids.shape[1])
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
@ -109,7 +118,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, tf.identity(scores))
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
@ -188,15 +197,18 @@ class TFLogitsProcessorTest(unittest.TestCase):
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
cur_len = 4
input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32)
self.assertEqual(cur_len, input_ids.shape[1])
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores))
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores))
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores), cur_len)
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores), cur_len)
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
@ -212,14 +224,17 @@ class TFLogitsProcessorTest(unittest.TestCase):
vocab_size = 5
batch_size = 2
eos_token_id = 4
cur_len = 4
input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32)
self.assertEqual(cur_len, input_ids.shape[1])
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores))
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
@ -228,14 +243,65 @@ class TFLogitsProcessorTest(unittest.TestCase):
[[True, True, False, True, True], [True, True, True, False, True]],
)
def test_forced_bos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
# check that all scores are -inf except the bos_token_id score
cur_len = 1
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertTrue(
tf.math.reduce_all(tf.math.is_inf(scores[:, bos_token_id + 1 :]) & (scores[:, bos_token_id + 1 :] < 0))
)
self.assertListEqual(scores[:, bos_token_id].numpy().tolist(), 4 * [0]) # score for bos_token_id should be zero
# check that bos_token_id is not forced if current length is greater than 1
cur_len = 4
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
def test_forced_eos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
cur_len = 4
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertTrue(
tf.math.reduce_all(tf.math.is_inf(scores[:, eos_token_id + 1 :]) & (scores[:, eos_token_id + 1 :] < 0))
)
self.assertListEqual(
scores[:, eos_token_id].numpy().tolist(), 4 * [0]
) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length-1 is not reached
cur_len = 3
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
def test_processor_list(self):
batch_size = 4
sequence_length = 10
cur_len = 10
vocab_size = 15
eos_token_id = 0
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids = ids_tensor((batch_size, cur_len), vocab_size)
input_ids_comp = tf.identity(input_ids)
scores = self._get_uniform_logits(batch_size, vocab_size)
@ -251,13 +317,13 @@ class TFLogitsProcessorTest(unittest.TestCase):
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = min_dist_proc(input_ids, scores, cur_len)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores, cur_len)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
scores = no_repeat_proc(input_ids, scores, cur_len)
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
# with processor list
processor = TFLogitsProcessorList(
@ -271,7 +337,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)

View File

@ -536,7 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
"no_repeat_ngram_size": 2,
"do_sample": False,
"repetition_penalty": 1.3,
"num_beams": 2,
}
@ -544,8 +543,8 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and I hope you enjoy it.\nI am very happy to announce that",
"Yesterday was the first time I've ever seen a game where you can play with",
"Today is a beautiful day and a great day for all of us.\n\nIm",
"Yesterday was the first day of the year for the second time in a row,",
]
self.assertListEqual(output_strings, expected_output_string)

View File

@ -508,7 +508,7 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, unittest.TestCase):
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# generating more sequences than we have beams is not possible
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)

View File

@ -1179,7 +1179,7 @@ class TFModelTesterMixin:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# generating more sequences than we have beams is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)