mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

TF generate refactor - Beam Search (#16374)

* refactor TF beam search
* refactored generate can now properly use attention masks
* add force bos/eos logit processors

This commit is contained in:
parent 4d10083539
commit 3f43d824b9
@@ -178,6 +178,12 @@ generation.
 [[autodoc]] TFRepetitionPenaltyLogitsProcessor
     - __call__

+[[autodoc]] TFForcedBOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] TFForcedEOSTokenLogitsProcessor
+    - __call__
+
 [[autodoc]] FlaxLogitsProcessor
     - __call__
@@ -1699,6 +1699,8 @@ if is_tf_available():
     _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
     _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
     _import_structure["generation_tf_logits_process"] = [
+        "TFForcedBOSTokenLogitsProcessor",
+        "TFForcedEOSTokenLogitsProcessor",
         "TFLogitsProcessor",
         "TFLogitsProcessorList",
         "TFLogitsWarper",
@@ -3827,6 +3829,8 @@ if TYPE_CHECKING:
     # Benchmarks
     from .benchmark.benchmark_tf import TensorFlowBenchmark
     from .generation_tf_logits_process import (
+        TFForcedBOSTokenLogitsProcessor,
+        TFForcedEOSTokenLogitsProcessor,
         TFLogitsProcessor,
         TFLogitsProcessorList,
         TFLogitsWarper,
@@ -216,14 +216,10 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
         self.min_length = min_length
         self.eos_token_id = eos_token_id

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        # create boolean flag to decide if min length penalty should be applied
-        cur_len = input_ids.shape[-1]
-        apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)
-
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
         # generate is not XLA-compilable anyway
-        if apply_penalty:
+        if cur_len < self.min_length:
             eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
             scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
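The behavioral change here: `cur_len` is now passed in explicitly instead of being read from `input_ids.shape`, which keeps the processor usable with fixed-shape padded inputs. A minimal usage sketch of the new calling convention (toy batch/vocab sizes chosen for illustration, assuming a TF install):

```python
import tensorflow as tf
from transformers.generation_tf_logits_process import TFMinLengthLogitsProcessor

processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=3)
scores = tf.zeros((2, 8))  # (batch_size, vocab_size)

# below min_length: the EOS column is pushed to -inf so generation cannot stop yet
short = processor(tf.ones((2, 5), dtype=tf.int32), scores, cur_len=5)
print(short[:, 3].numpy())  # [-inf -inf]

# at or past min_length: scores pass through untouched
passed = processor(tf.ones((2, 12), dtype=tf.int32), scores, cur_len=12)
print(passed[:, 3].numpy())  # [0. 0.]
```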
@@ -259,8 +255,8 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
             np.put(token_penalties[i], prev_input_id, logit_penalties)
         return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        score_penalties = self._create_score_penalties(input_ids, scores)
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)

         scores = tf.math.multiply(scores, score_penalties)
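The penalty rule itself is unchanged: for every token already present in the (now `cur_len`-truncated) prefix, a positive logit is divided by the penalty and a negative one multiplied by it, so repeats become less likely either way. A toy NumPy sketch of that rule (illustrative only, mirroring what `_create_score_penalties` builds before the final multiply):

```python
import numpy as np

penalty = 2.0
logits = np.array([0.4, -0.6, 0.1])  # scores for tokens 0..2
seen = [0, 1]                        # tokens already in the prefix

penalties = np.ones_like(logits)
for tok in seen:
    # seen positive logits shrink (divide), seen negative logits grow more negative (multiply)
    penalties[tok] = 1 / penalty if logits[tok] > 0 else penalty

print(logits * penalties)  # [ 0.2 -1.2  0.1] -> both seen tokens are now less likely
```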
@@ -330,12 +326,12 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):

         return banned_tokens

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

         vocab_size = scores.shape[-1]

         # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids)
+        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])

         banned_tokens_indices_mask = []
         for banned_tokens_slice in banned_tokens:
@@ -365,12 +361,13 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
         self.ngram_size = ngram_size

-    def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
+    def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
         # Copied from fairseq for no_repeat_ngram in beam_search
         if cur_len + 1 < self.ngram_size:
             # return no banned tokens if we haven't generated ngram_size tokens yet
             return [[] for _ in range(num_hypos)]
         generated_ngrams = [{} for _ in range(num_hypos)]
+        prev_input_ids = input_ids[:, :cur_len]
         for idx in range(num_hypos):
             gen_tokens = prev_input_ids[idx].numpy().tolist()
             generated_ngram = generated_ngrams[idx]
@@ -388,10 +385,9 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):

         return banned_tokens

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

         batch_size, vocab_size = scores.shape
-        cur_len = input_ids.shape[-1]
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

         # create banned_tokens boolean mask
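For intuition, here is a minimal standalone sketch (plain Python, hypothetical toy values, not the library's vectorized code) of the n-gram banning rule the processor implements: a token is banned if appending it would repeat an n-gram already present in the prefix.

```python
def banned_next_tokens(prefix: list, ngram_size: int = 2) -> list:
    # collect all n-grams seen so far: map (n-1)-gram prefix -> tokens that followed it
    seen = {}
    for i in range(len(prefix) - ngram_size + 1):
        key = tuple(prefix[i : i + ngram_size - 1])
        seen.setdefault(key, []).append(prefix[i + ngram_size - 1])
    # the last (n-1) generated tokens decide which continuations would repeat an n-gram
    last = tuple(prefix[len(prefix) - ngram_size + 1 :])
    return seen.get(last, [])

# matches the expectation in the test below: for [1, 1, 2, 1], a 2-gram ban forbids 1 and 2
assert banned_next_tokens([1, 1, 2, 1], ngram_size=2) == [1, 2]
```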
@@ -406,3 +402,66 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
         )

         return scores
+
+
+class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the first generated token.
+
+    Args:
+        bos_token_id (`int`):
+            The id of the token to force as the first generated token.
+    """
+
+    def __init__(self, bos_token_id: int):
+        if bos_token_id < 0:
+            raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
+        self.bos_token_id = bos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the bos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.bos_token_id > 0:
+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
+            if self.bos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
+                    axis=-1,
+                )
+        return scores
+
+
+class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
+
+    Args:
+        max_length (`int`):
+            The maximum length of the sequence to be generated.
+        eos_token_id (`int`):
+            The id of the token to force as the last generated token when `max_length` is reached.
+    """
+
+    def __init__(self, max_length: int, eos_token_id: int):
+        self.max_length = max_length
+        if eos_token_id < 0:
+            raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == self.max_length - 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the eos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.eos_token_id > 0:
+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
+            if self.eos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
+                    axis=-1,
+                )
+        return scores
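The two new processors build the forced distribution without scatter ops: a zero column for the forced token, concatenated with `-inf` blocks on either side. A self-contained sketch of that concat construction (toy sizes and a made-up `forced_id`, not the library API):

```python
import tensorflow as tf

batch_size, vocab_size, forced_id = 2, 5, 3

# zero score for the forced token...
scores = tf.zeros((batch_size, 1))
# ...-inf for every id below it...
if forced_id > 0:
    scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, forced_id)), scores), axis=-1)
# ...and -inf for every id above it
if forced_id < vocab_size - 1:
    scores = tf.concat((scores, tf.broadcast_to(-float("inf"), (batch_size, vocab_size - 1 - forced_id))), axis=-1)

print(scores.numpy())  # each row: [-inf, -inf, -inf, 0., -inf] -> argmax is always forced_id
```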
@@ -16,12 +16,15 @@

 import inspect
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import tensorflow as tf

 from .generation_tf_logits_process import (
+    TFForcedBOSTokenLogitsProcessor,
+    TFForcedEOSTokenLogitsProcessor,
     TFLogitsProcessorList,
     TFMinLengthLogitsProcessor,
     TFNoBadWordsLogitsProcessor,
@@ -560,7 +563,7 @@ class TFGenerationMixin:
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         do_sample = do_sample if do_sample is not None else self.config.do_sample

-        if num_beams == 1:
+        if do_sample is False or num_beams == 1:
             return self._generate(
                 input_ids=input_ids,
                 max_length=max_length,
@@ -586,6 +589,8 @@ class TFGenerationMixin:
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
+                forced_bos_token_id=forced_bos_token_id,
+                forced_eos_token_id=forced_eos_token_id,
             )

         # We cannot generate if the model does not have a LM head
@@ -1377,8 +1382,8 @@ class TFGenerationMixin:
                 the target language token.
             forced_eos_token_id (`int`, *optional*):
                 The id of the token to force as the last generated token when `max_length` is reached.
-            model_specific_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `call` function of the model.

         Return:
             [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
@@ -1452,12 +1457,20 @@ class TFGenerationMixin:
         # 1. Set generation parameters if not already defined
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

+        forced_bos_token_id = (
+            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
+        )
+        forced_eos_token_id = (
+            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
+        )
+
         output_scores = output_scores if output_scores is not None else self.config.output_scores
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1489,10 +1502,13 @@ class TFGenerationMixin:
             model_kwargs["output_hidden_states"] = output_hidden_states
         if use_cache is not None:
             model_kwargs["use_cache"] = use_cache
+        if attention_mask is not None:
+            model_kwargs["attention_mask"] = attention_mask

+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)

         # 4. Prepare model inputs which will be used for auto-regressive generation
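The new `accepts_attention_mask` guard relies on standard-library introspection: `inspect.signature` exposes a callable's parameter names, so `generate` only builds a mask for models whose `call` actually takes one. A small self-contained illustration with hypothetical model stubs (not the real classes):

```python
import inspect

def call_with_mask(input_ids, attention_mask=None):  # e.g. a text model's call
    ...

def call_without_mask(pixel_values):  # e.g. a vision model's call
    ...

for fn in (call_with_mask, call_without_mask):
    accepts = "attention_mask" in inspect.signature(fn).parameters
    print(fn.__name__, accepts)
# call_with_mask True
# call_without_mask False
```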
@@ -1519,6 +1535,7 @@ class TFGenerationMixin:
         # TODO(Matt, Joao, Patrick) - add more use cases here
         is_greedy_gen_mode = (num_beams == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and do_sample is True
+        is_beam_gen_mode = (num_beams > 1) and do_sample is False

         # 6. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
@@ -1526,7 +1543,10 @@ class TFGenerationMixin:
             no_repeat_ngram_size=no_repeat_ngram_size,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
+            max_length=max_length,
             eos_token_id=eos_token_id,
+            forced_bos_token_id=forced_bos_token_id,
+            forced_eos_token_id=forced_eos_token_id,
         )

         # 7. go into different generation modes
@@ -1571,18 +1591,62 @@ class TFGenerationMixin:
                 **model_kwargs,
             )

+        elif is_beam_gen_mode:
+            if num_beams < num_return_sequences:
+                raise ValueError(
+                    "Greedy beam search decoding cannot return more sequences than it has beams. Please set "
+                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
+                )
+
+            # 8. broadcast inputs to the desired number of beams
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+
+            if "encoder_outputs" in model_kwargs:
+                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                )
+
+            if "attention_mask" in model_kwargs:
+                model_kwargs["attention_mask"] = self._expand_to_num_beams(
+                    model_kwargs["attention_mask"], num_beams=num_beams
+                )
+
+            # 9. run beam search
+            return self.beam_search(
+                input_ids,
+                max_length=max_length,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                length_penalty=length_penalty,
+                early_stopping=early_stopping,
+                logits_processor=logits_processor,
+                return_dict_in_generate=return_dict_in_generate,
+                num_return_sequences=num_return_sequences,
+                **model_kwargs,
+            )
+
         else:
             # TODO(Matt, Joao, Patrick) - add more sub-generation methods here
             raise NotImplementedError("Beam sampling is currently not implemented.")

+    @staticmethod
+    def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
+        return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
+
     def _prepare_attention_mask_for_generation(
         self,
-        input_ids: tf.Tensor,
+        inputs: tf.Tensor,
         pad_token_id: int,
     ) -> tf.Tensor:
-        # prepare `attention_mask` if not passed
-        if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id):
-            return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
+        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
+        is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
+        # Check if input is input_ids and padded -> only then is attention_mask defined
+        if is_input_ids and is_pad_token_in_inputs:
+            return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
         else:
-            return tf.ones(input_ids.shape[:2], dtype=tf.int32)
+            return tf.ones(inputs.shape[:2], dtype=tf.int32)

-    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]:
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
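`_expand_to_num_beams` above copies each batch item across a new beam axis purely by broadcasting; no data is materialized until a later reshape flattens the batch and beam axes together. A quick standalone check of the shape arithmetic (toy tensor, assumed shapes):

```python
import tensorflow as tf

def expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
    # insert a beam axis after the batch axis and broadcast across it
    return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])

input_ids = tf.constant([[1, 2, 3], [4, 5, 6]])       # (batch=2, seq=3)
expanded = expand_to_num_beams(input_ids, num_beams=4)
print(expanded.shape)                                  # (2, 4, 3)
print(expanded[0].numpy())                             # every beam of item 0 starts as [1, 2, 3]
```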
@@ -1595,11 +1659,9 @@ class TFGenerationMixin:
         }

-        # vision models don't use `attention_mask`.
-        signature = dict(inspect.signature(encoder.call).parameters)
-        if "attention_mask" not in signature:
-            encoder_kwargs.pop("attention_mask")
-
-        encoder_outputs = encoder(input_ids, **encoder_kwargs)
+        encoder_kwargs["return_dict"] = True
+        encoder_kwargs[self.main_input_name] = inputs_tensor
+        encoder_outputs = encoder(**encoder_kwargs)
         model_kwargs["encoder_outputs"] = encoder_outputs

         return model_kwargs
@@ -1757,7 +1819,10 @@ class TFGenerationMixin:
         no_repeat_ngram_size: int,
         bad_words_ids: List[List[int]],
         min_length: int,
+        max_length: int,
         eos_token_id: int,
+        forced_bos_token_id: int,
+        forced_eos_token_id: int,
     ) -> TFLogitsProcessorList:
         """
         This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
@@ -1781,6 +1846,10 @@ class TFGenerationMixin:
             processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
         if min_length is not None and eos_token_id is not None and min_length > 0:
             processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
+        if forced_bos_token_id is not None:
+            processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
+        if forced_eos_token_id is not None:
+            processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))

         return processors
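Outside of `generate`, the same list can be assembled and applied by hand; every TF processor touched by this PR now shares the `(input_ids, scores, cur_len)` calling convention. A hedged usage sketch with made-up token ids and sizes (mirroring the test file's import path, not a recommended production setup):

```python
import tensorflow as tf
from transformers.generation_tf_logits_process import (
    TFForcedBOSTokenLogitsProcessor,
    TFForcedEOSTokenLogitsProcessor,
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
)

batch_size, vocab_size, cur_len = 2, 50, 1
input_ids = tf.ones((batch_size, cur_len), dtype=tf.int32)
scores = tf.zeros((batch_size, vocab_size))

processors = TFLogitsProcessorList(
    [
        TFForcedBOSTokenLogitsProcessor(bos_token_id=0),
        TFForcedEOSTokenLogitsProcessor(max_length=20, eos_token_id=1),
        TFMinLengthLogitsProcessor(min_length=5, eos_token_id=1),
    ]
)

# at cur_len == 1 the BOS processor pins the whole distribution to token 0
scores = processors(input_ids, scores, cur_len=cur_len)
print(tf.argmax(scores, axis=-1).numpy())  # -> [0 0]
```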
@@ -1949,7 +2018,7 @@ class TFGenerationMixin:
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])

             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2064,7 +2133,7 @@ class TFGenerationMixin:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
             logits_processor (`TFLogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
+                An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
             logits_warper (`TFLogitsProcessorList`, *optional*):
                 An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
@@ -2184,7 +2253,7 @@ class TFGenerationMixin:
             next_token_logits = outputs.logits[:, -1, :]

             # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
             next_token_scores = logits_warper(input_ids, next_token_scores)

             # Store scores, attentions and hidden_states when required
@@ -2259,6 +2328,527 @@ class TFGenerationMixin:
         else:
             return input_ids

+    def beam_search(
+        self,
+        input_ids: tf.Tensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        early_stopping: Optional[bool] = None,
+        logits_processor: Optional[TFLogitsProcessorList] = None,
+        num_return_sequences: Optional[int] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        **model_kwargs,
+    ) -> Union[TFBeamSearchOutput, tf.Tensor]:
+        r"""
+        Generates sequences for models with a language modeling head using beam search decoding.
+
+        Parameters:
+
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            max_length (`int`, *optional*, defaults to 20):
+                The maximum length of the sequence to be generated.
+            pad_token_id (`int`, *optional*):
+                The id of the *padding* token.
+            eos_token_id (`int`, *optional*):
+                The id of the *end-of-sequence* token.
+            length_penalty (`float`, *optional*, defaults to 1.0):
+                Exponential penalty to the length. 1.0 means no penalty.
+            early_stopping (`bool`, *optional*, defaults to `False`):
+                Whether to stop the beam search when at least `num_beams` sentences per batch are finished.
+            logits_processor (`[TFLogitsProcessorList]`, *optional*):
+                An instance of [`TFLogitsProcessorList`]. List of instances of class derived from
+                [`TFLogitsProcessor`] used to modify the prediction scores of the language modeling head applied at
+                each generation step.
+            num_return_sequences(`int`, *optional*, defaults to 1):
+                The number of independently computed returned sequences for each element in the batch.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `call` function of the model. If model is
+                an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+        Return:
+            [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`],
+            [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the
+            generated tokens (default behaviour) or a [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`] if
+            `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
+            [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import (
+        ...     AutoTokenizer,
+        ...     TFAutoModelForSeq2SeqLM,
+        ...     TFLogitsProcessorList,
+        ...     TFMinLengthLogitsProcessor,
+        ... )
+        >>> import tensorflow as tf
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
+        >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")
+
+        >>> encoder_input_str = "translate English to German: How old are you?"
+        >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
+
+        >>> # lets run beam search using 3 beams
+        >>> num_beams = 3
+        >>> # define decoder start token ids
+        >>> input_ids = tf.ones((num_beams, 1), dtype=tf.int64)
+        >>> input_ids = input_ids * model.config.decoder_start_token_id
+
+        >>> # add encoder_outputs to model keyword arguments
+        >>> model_kwargs = {
+        ...     "encoder_outputs": model.get_encoder()(
+        ...         tf.repeat(encoder_input_ids, num_beams, axis=0), return_dict=True
+        ...     )
+        ... }
+
+        >>> # instantiate logits processors
+        >>> logits_processor = TFLogitsProcessorList(
+        ...     [TFMinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
+        ... )
+
+        >>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs)
+
+        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        ```"""
+
+        def flatten_beam_dim(tensor, batch_axis=0):
+            """Flattens the first two dimensions of a non-scalar array."""
+            # ignore scalars (e.g. cache index)
+            if tf.rank(tensor) == 0:
+                return tensor
+            return tf.reshape(
+                tensor,
+                tensor.shape[:batch_axis]
+                + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
+                + tensor.shape[batch_axis + 2 :],
+            )
+
+        def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
+            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
+            # ignore scalars (e.g. cache index)
+            if tf.rank(tensor) == 0:
+                return tensor
+            return tf.reshape(
+                tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
+            )
+
+        def gather_beams(nested, beam_indices, batch_axis=0):
+            """Gathers the beam slices indexed by beam_indices into new beam array."""
+
+            def gather_fn(tensor):
+                # ignore scalars (e.g. cache index)
+                if tf.rank(tensor) == 0:
+                    return tensor
+                else:
+                    if batch_axis > 0:
+                        # pushes all dimensions before the batch to the end, so we get (batch, beam_id, ...)
+                        perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
+                            range(batch_axis)
+                        )
+                        tensor = tf.transpose(tensor, perm=perm)
+
+                    gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
+                    if batch_axis > 0:
+                        # transposes back to the original dimensions
+                        perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list(
+                            range(batch_axis)
+                        )
+                        perm = tf.math.invert_permutation(perm)
+                        gathered_tensor = tf.transpose(gathered_tensor, perm=perm)
+
+                    return gathered_tensor
+
+            return tf.nest.map_structure(gather_fn, nested)
+
+        # 1. init beam_search values
+        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
+
+        max_length = max_length if max_length is not None else self.config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        num_return_sequences = (
+            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+        )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        return_dict_in_generate = (
+            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+        )
+
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+
+        use_xla = not tf.executing_eagerly()
+        # TODO (Joao): fix cache format or find programmatic way to detect cache index
+        # GPT2 and other models have a slightly different cache structure, with a different batch axis
+        model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
+        cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+
+        # 2. init `attentions`, `hidden_states`, and `scores` tuples
+        scores = [] if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
+
+        # 3. init tensors to use for "xla-compilable" generate function
+        batch_size, num_beams, cur_len = input_ids.shape
+
+        # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
+        sequences = tf.TensorArray(
+            element_shape=(batch_size, num_beams),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        running_sequences = tf.TensorArray(
+            element_shape=(batch_size, num_beams),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        intermediary_running_sequences = tf.TensorArray(
+            element_shape=(batch_size, num_beams * 2),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        for i in range(max_length):
+            sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+            running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+            intermediary_running_sequences = intermediary_running_sequences.write(
+                i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
+            )
+
+        # write prompt to running_sequences
+        for i in range(cur_len):
+            running_sequences = running_sequences.write(i, input_ids[:, :, i])
+
+        # per batch, beam-item state bit indicating if sentence has finished.
+        is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
+
+        # per batch, beam-item score, logprobs
+        running_scores = tf.tile(
+            tf.expand_dims(tf.convert_to_tensor([0.0] + [-1.0e9] * (num_beams - 1)), axis=0), [batch_size, 1]
+        )
+        scores = tf.ones((batch_size, num_beams)) * -1.0e9
+
+        # flatten beam dim
+        if "encoder_outputs" in model_kwargs:
+            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
+                model_kwargs["encoder_outputs"]["last_hidden_state"]
+            )
+        if "attention_mask" in model_kwargs:
+            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
+
+        # 4. define "xla-compile-able" stop-condition and auto-regressive function
+        def beam_search_cond_fn(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+        ):
+            """
+            Beam Search termination condition function -- halts the generation loop if any of these conditions
+            becomes False
+            """
+            # 1. is less than max length?
+            not_max_length_yet = cur_len < max_length
+
+            # 2. can the new beams still improve?
+            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            worst_finished_score = tf.where(
+                is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
+            )
+            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+
+            # 3. is there still a beam that has not finished?
+            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
+
+            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+
+        def beam_search_body_fn(
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length=1,
+            intermediary_running_sequences=None,
+        ):
+            """
+            Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
+            seen so far
+            """
+            # TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
+            # Alternatively, attempt to rewrite function with permuted axis, when enabling XLA.
+
+            # 1. Forward current tokens
+            # TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
+            running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
+            input_token = tf.slice(
+                running_sequences_seq_last,
+                (0, 0, cur_len - input_ids_length),
+                (batch_size, num_beams, input_ids_length),
+            )
+            model_inputs = self.prepare_inputs_for_generation(
+                flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs
+            )
+            model_outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
+
+            # Store scores, attentions and hidden_states when required
+            if not use_xla and return_dict_in_generate:
+                if output_scores:
+                    scores.append(model_outputs.logits[:, -1])
+                if output_attentions and self.config.is_encoder_decoder:
+                    decoder_attentions.append(model_outputs.decoder_attentions)
+                elif output_attentions and not self.config.is_encoder_decoder:
+                    decoder_attentions.append(model_outputs.attentions)
+                if output_attentions and self.config.is_encoder_decoder:
+                    cross_attentions.append(model_outputs.cross_attentions)
+
+                if output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(model_outputs.decoder_hidden_states)
+                elif output_hidden_states and not self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(model_outputs.hidden_states)
+
+            # 2. Compute log probs
+            # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
+            # add new logprobs to existing running logprobs scores.
+            log_probs = tf.nn.log_softmax(logits)
+            log_probs = logits_processor(
+                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
+            )
+            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
+            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
+            vocab_size = log_probs.shape[2]
+            log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
+
+            # 3. Retrieve top-K
+            # Each item in batch has num_beams * vocab_size candidate sequences. For each item, get the top 2*K
+            # candidates with the highest log-probabilities. We gather the top 2*K beams here so that even if the
+            # best K sequences reach EOS simultaneously, we have another K sequences remaining to continue the live
+            # beam search.
+            # Gather the top 2*K scores from _all_ beams.
+            # Gather 2*K top beams.
+            # Recover the beam index by floor division.
+            # Recover token id by modulo division and expand Id array for broadcasting.
+            # Update sequences for the 2*K top-k new sequences.
+            beams_to_keep = 2 * num_beams
+            topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
+            topk_beam_indices = topk_indices // vocab_size
+            topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
+            topk_ids = topk_indices % vocab_size
+
+            # writes the new token
+            intermediary_running_sequences = intermediary_running_sequences.unstack(
+                tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
+            )
+            topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
+            topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
+
+            # 4. Check which sequences have ended
+            # Update current sequences: Did the top `num_beams` sequences reach an end marker?
+            # To prevent these just finished sequences from being added to the current sequences
+            # set of active beam search sequences, set their log probs to a very large negative value.
+            eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
+            if eos_token_id is None:
+                eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
+            did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
+                tf.concat((tf.ones((num_beams,), dtype=tf.bool), tf.zeros((num_beams,), dtype=tf.bool)), axis=0),
+                eos_in_next_token.shape,
+            )
+
+            # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the
+            # next running sentences either
+            running_topk_log_probs = topk_log_probs + tf.cast(eos_in_next_token, tf.float32) * -1.0e9
+
+            # 5. Get running sequences scores for next
+            # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
+            # (from top 2*k beams).
+            next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
+            next_running_sequences_seq_last, next_running_scores = gather_beams(
+                [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
+            )
+
+            # 6. Process topk logits
+            # Further process log probs:
+            # - add length penalty
+            # - make sure no scores can be added anymore if beam is full
+            # - make sure still running sequences cannot be chosen as finalized beam
+            topk_log_probs = topk_log_probs / (cur_len**length_penalty)
+            beams_in_batch_are_full = (
+                tf.broadcast_to(
+                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
+                )
+                & early_stopping
+            )
+            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
+            topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9
+
+            # 7. Get scores, sequences, is sentence finished for next.
+            # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
+            # to existing finished scores and select the best from the new set of beams
+            sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
+            merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
+            merged_scores = tf.concat([scores, topk_log_probs], axis=1)
+            merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
+            topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
+            next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
+                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
+            )
+
+            # 8. Prepare data for the next iteration
+            # Determine the top k beam indices from the original set of all beams. With these, gather the top k
+            # beam-associated caches.
+            if "past_key_values" in model_outputs:
+                cache = tf.nest.map_structure(
+                    lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
+                    model_outputs.past_key_values,
+                )
+                next_running_indices = gather_beams(topk_beam_indices, next_topk_indices)
+                next_cache = gather_beams(cache, next_running_indices, batch_axis=cache_batch_axis)
+                model_outputs["past_key_values"] = tf.nest.map_structure(
+                    lambda tensor: flatten_beam_dim(tensor, batch_axis=cache_batch_axis), next_cache
+                )
+
+            if use_xla:
+                next_model_kwargs = self._update_model_kwargs_for_xla_generation(
+                    model_outputs, model_kwargs, cur_len, max_length
+                )
+            else:
+                next_model_kwargs = self._update_model_kwargs_for_generation(
+                    model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+
+                # if we don't cache past key values we need the whole input
+                if model_kwargs.get("past", None) is None:
+                    input_ids_length = cur_len + 1
+                    # let's throw out `past` since we don't want `None` tensors
+                    model_kwargs.pop("past", None)
+
+            # 9. Prepare the `tf.TensorArray` for the next iteration
+            next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
+            next_running_sequences = running_sequences.unstack(
+                tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
+            )
+
+            return (
+                cur_len + 1,
+                next_running_sequences,
+                next_running_scores,
+                next_sequences,
+                next_scores,
+                next_is_sent_finished,
+                next_model_kwargs,
+            )
+
+        # 5. run generation
+        # Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
+        beam_search_body_fn = partial(
+            beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
+        )
+
+        # 1st generation step has to be run before to initialize `past`
+        beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len)
+        (
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+        ) = beam_search_body_fn_first_iter(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+        )
+
+        # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
+        # NOT yield EOS token though)
+        if beam_search_cond_fn(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+        ):
+            maximum_iterations = max_length - cur_len
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
+                beam_search_cond_fn,
+                beam_search_body_fn,
+                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
+
+        # 6. prepare outputs
+        # convert the sequences to tf.Tensor with shape (batch_size, num_beams, seq_len)
+        sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
+        running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
+
+        # Account for the edge-case where there are no finished sequences for a particular batch item. If so,
+        # return running sequences for that batch item.
+        none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
+        sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+        scores = tf.where(none_finished[:, None], scores, running_scores)
+
+        # Take best beams for each batch (the score is sorted in descending order)
+        sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+        scores = flatten_beam_dim(scores[:, :num_return_sequences])
+
+        if not use_xla:
+            # Cut for backward compatibility
+            sequences_seq_last = sequences_seq_last[:, :cur_len]
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                return TFBeamSearchEncoderDecoderOutput(
+                    sequences=sequences_seq_last,
+                    scores=scores,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                )
+            else:
+                return TFBeamSearchDecoderOnlyOutput(
+                    sequences=sequences_seq_last,
+                    scores=scores,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                )
+        else:
+            return sequences_seq_last
+
+
 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
     # create logit penalties for already seen input_ids
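One detail worth internalizing from step 3 of the loop body above: after flattening the `(num_beams, vocab_size)` candidates into one axis, a single `top_k` call picks the best continuations, and the source beam and token id are recovered with floor division and modulo. A minimal sketch with toy numbers (assumed shapes, not the library API):

```python
import tensorflow as tf

batch_size, num_beams, vocab_size = 1, 2, 5
# log-prob of every (beam, token) continuation, flattened to num_beams * vocab_size
log_probs = tf.random.normal((batch_size, num_beams * vocab_size))

beams_to_keep = 2 * num_beams  # keep 2*K so simultaneous EOS hits cannot starve the live beams
topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)

topk_beam_indices = topk_indices // vocab_size  # which beam each candidate came from
topk_ids = topk_indices % vocab_size            # which token extends that beam

# e.g. flat index 7 with vocab_size=5 -> beam 1, token 2
assert (7 // vocab_size, 7 % vocab_size) == (1, 2)
```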
@@ -2445,7 +3035,6 @@ class BeamHypotheses(object):
         If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
         one in the heap, then we are done with this sentence.
         """
-
         if len(self) < self.num_beams:
             return False
         elif self.early_stopping:
@@ -1245,7 +1245,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             no_repeat_ngram_size=no_repeat_ngram_size,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
+            max_length=max_length,
             eos_token_id=eos_token_id,
+            forced_bos_token_id=None,
+            forced_eos_token_id=None,
         )
         model_kwargs["attention_mask"] = context_attention_mask
@@ -17,6 +17,20 @@ class TensorFlowBenchmark(metaclass=DummyObject):
         requires_backends(self, ["tf"])


+class TFForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFLogitsProcessor(metaclass=DummyObject):
     _backends = ["tf"]
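These dummies keep `from transformers import TFForcedBOSTokenLogitsProcessor` importable in a TensorFlow-free environment while failing loudly on use. A rough sketch of the pattern, with simplified stand-ins for `DummyObject` and `requires_backends` (the real helpers live in the library's utils and check backend availability rather than raising unconditionally):

```python
class DummyObject(type):
    # metaclass: non-private attribute access on the class raises, not just instantiation
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)

def requires_backends(obj, backends):
    # simplified: the real helper only raises when a backend is actually missing
    raise ImportError(f"{getattr(obj, '__name__', obj)} requires: {backends}")

class TFForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

# importing succeeds; only using the class raises a helpful error
try:
    TFForcedBOSTokenLogitsProcessor(bos_token_id=0)
except ImportError as e:
    print(e)
```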
@@ -472,14 +472,14 @@ class LogitsProcessorTest(unittest.TestCase):

         logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)

-        # check that all scores are -inf except the eos_token_id when max_length is reached
+        # check that all scores are -inf except the eos_token_id when max_length-1 is reached
         input_ids = ids_tensor((batch_size, 4), vocab_size=20)
         scores = self._get_uniform_logits(batch_size, vocab_size)
         scores = logits_processor(input_ids, scores)
         self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
         self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0])  # score for eos_token_id should be zero

-        # check that eos_token_id is not forced if max_length is not reached
+        # check that eos_token_id is not forced if max_length-1 is not reached
         input_ids = ids_tensor((batch_size, 3), vocab_size=20)
         scores = self._get_uniform_logits(batch_size, vocab_size)
         scores = logits_processor(input_ids, scores)
@@ -26,6 +26,8 @@ if is_tf_available():
     import tensorflow as tf

     from transformers.generation_tf_logits_process import (
+        TFForcedBOSTokenLogitsProcessor,
+        TFForcedEOSTokenLogitsProcessor,
         TFLogitsProcessorList,
         TFMinLengthLogitsProcessor,
         TFNoBadWordsLogitsProcessor,
@@ -43,7 +45,7 @@ if is_tf_available():
 @require_tf
 class TFLogitsProcessorTest(unittest.TestCase):
     def _get_uniform_logits(self, batch_size: int, length: int):
-        scores = np.ones((batch_size, length), dtype=np.float32) / length
+        scores = tf.ones((batch_size, length), dtype=tf.float32) / length
         return scores

     def test_min_length_dist_processor(self):
@@ -54,15 +56,17 @@ class TFLogitsProcessorTest(unittest.TestCase):
         min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)

         # check that min length is applied at length 5
-        input_ids = ids_tensor((batch_size, 5), vocab_size=20)
+        cur_len = 5
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
         scores = self._get_uniform_logits(batch_size, vocab_size)
-        scores_before_min_length = min_dist_processor(input_ids, scores)
+        scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
         self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")])

         # check that min length is not applied anymore at length 15
-        input_ids = ids_tensor((batch_size, 15), vocab_size=20)
+        cur_len = 15
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
         scores = self._get_uniform_logits(batch_size, vocab_size)
-        scores_before_min_length = min_dist_processor(input_ids, scores)
+        scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())

     def test_temperature_dist_warper(self):
@@ -72,8 +76,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = self._get_uniform_logits(batch_size=2, length=length)

         # tweak scores to not be uniform anymore
+        scores = scores.numpy()
         scores[1, 5] = (1 / length) + 0.1  # peak, 1st batch
         scores[1, 10] = (1 / length) - 0.4  # valley, 1st batch
+        scores = tf.convert_to_tensor(scores)

         # compute softmax
         probs = tf.nn.softmax(scores, axis=-1)
@@ -97,8 +103,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))

     def test_repetition_penalty_dist_process(self):
-        input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
         vocab_size = 10
+        cur_len = 2
+
+        input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
+        self.assertEqual(cur_len, input_ids.shape[1])

         scores = self._get_uniform_logits(batch_size=2, length=vocab_size)

@@ -109,7 +118,7 @@ class TFLogitsProcessorTest(unittest.TestCase):

         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)

-        scores = rep_penalty_proc(input_ids, tf.identity(scores))
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)

         # check that values were correctly changed
         self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
@@ -188,15 +197,18 @@ class TFLogitsProcessorTest(unittest.TestCase):
     def test_no_repeat_ngram_dist_processor(self):
         vocab_size = 3
         batch_size = 2
+        cur_len = 4

         input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32)
+        self.assertEqual(cur_len, input_ids.shape[1])

         scores = self._get_uniform_logits(batch_size, vocab_size)

         no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2)
         no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3)

-        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores))
-        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores))
+        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores), cur_len)
+        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores), cur_len)

         # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
         self.assertListEqual(
@@ -212,14 +224,17 @@ class TFLogitsProcessorTest(unittest.TestCase):
         vocab_size = 5
         batch_size = 2
         eos_token_id = 4
+        cur_len = 4

         input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32)
+        self.assertEqual(cur_len, input_ids.shape[1])

         bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
         scores = self._get_uniform_logits(batch_size, vocab_size)

         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)

-        filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores))
+        filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)

         # batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
         # batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
@@ -228,14 +243,65 @@ class TFLogitsProcessorTest(unittest.TestCase):
             [[True, True, False, True, True], [True, True, True, False, True]],
         )

+    def test_forced_bos_token_logits_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        bos_token_id = 0
+
+        logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
+
+        # check that all scores are -inf except the bos_token_id score
+        cur_len = 1
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len)
+        self.assertTrue(
+            tf.math.reduce_all(tf.math.is_inf(scores[:, bos_token_id + 1 :]) & (scores[:, bos_token_id + 1 :] < 0))
+        )
+        self.assertListEqual(scores[:, bos_token_id].numpy().tolist(), 4 * [0])  # score for bos_token_id should be zero
+
+        # check that bos_token_id is not forced if current length is greater than 1
+        cur_len = 4
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len)
+        self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores)))
+
+    def test_forced_eos_token_logits_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+        max_length = 5
+
+        logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
+
+        # check that all scores are -inf except the eos_token_id when max_length-1 is reached
+        cur_len = 4
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len)
+        self.assertTrue(
+            tf.math.reduce_all(tf.math.is_inf(scores[:, eos_token_id + 1 :]) & (scores[:, eos_token_id + 1 :] < 0))
+        )
+        self.assertListEqual(
+            scores[:, eos_token_id].numpy().tolist(), 4 * [0]
+        )  # score for eos_token_id should be zero
+
+        # check that eos_token_id is not forced if max_length-1 is not reached
+        cur_len = 3
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores, cur_len)
+        self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores)))
+
     def test_processor_list(self):
         batch_size = 4
-        sequence_length = 10
+        cur_len = 10
         vocab_size = 15
         eos_token_id = 0

         # dummy input_ids and scores
-        input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
+        input_ids = ids_tensor((batch_size, cur_len), vocab_size)
         input_ids_comp = tf.identity(input_ids)

         scores = self._get_uniform_logits(batch_size, vocab_size)
@@ -251,13 +317,13 @@ class TFLogitsProcessorTest(unittest.TestCase):
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)

         # no processor list
-        scores = min_dist_proc(input_ids, scores)
+        scores = min_dist_proc(input_ids, scores, cur_len)
         scores = temp_dist_warp(input_ids, scores)
-        scores = rep_penalty_proc(input_ids, scores)
+        scores = rep_penalty_proc(input_ids, scores, cur_len)
         scores = top_k_warp(input_ids, scores)
         scores = top_p_warp(input_ids, scores)
-        scores = no_repeat_proc(input_ids, scores)
-        scores = no_bad_words_dist_proc(input_ids, scores)
+        scores = no_repeat_proc(input_ids, scores, cur_len)
+        scores = no_bad_words_dist_proc(input_ids, scores, cur_len)

         # with processor list
         processor = TFLogitsProcessorList(
@@ -271,7 +337,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
                 no_bad_words_dist_proc,
             ]
         )
-        scores_comp = processor(input_ids, scores_comp)
+        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

         # remove inf
         scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
|
@ -536,7 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||
"no_repeat_ngram_size": 2,
|
||||
"do_sample": False,
|
||||
"repetition_penalty": 1.3,
|
||||
"num_beams": 2,
|
||||
}
|
||||
|
||||
@@ -544,8 +543,8 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
-            "Today is a beautiful day and I hope you enjoy it.\nI am very happy to announce that",
-            "Yesterday was the first time I've ever seen a game where you can play with",
+            "Today is a beautiful day and a great day for all of us.\n\nI’m",
+            "Yesterday was the first day of the year for the second time in a row,",
         ]
         self.assertListEqual(output_strings, expected_output_string)
@@ -508,7 +508,7 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, unittest.TestCase):
         # if bos token id is not defined model needs input_ids, num_return_sequences = 1
         self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))

-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             # generating more sequences than we have beams is not possible
             model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
@@ -1179,7 +1179,7 @@ class TFModelTesterMixin:
         # num_return_sequences = 1
         self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))

-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             # generating more sequences than we have beams is not possible
             model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)